]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
convert.py : fix vanilla LLaMA model conversion (#4818)
authorAustin <redacted>
Tue, 9 Jan 2024 18:46:46 +0000 (13:46 -0500)
committerGitHub <redacted>
Tue, 9 Jan 2024 18:46:46 +0000 (20:46 +0200)
* Update Imports and Add Notes for Future Reference

- Updated import statements in `convert.py`.
- Added import for `AutoTokenizer` from `transformers` module.
- Added conditional import for `gguf` from the local directory.
- Added comments and notes for future reference.

Additional Notes:

- Noted removal of a redundant `TypeAlias` import.
- Noted the removal of a `gguf` debug statement.
- Commented on the presence of `ARCH` and `NDArray` definitions.
- Commented on cleaning up and refactoring data type definitions.

* Refine Model Hyperparameters and Params Class

- Updated type annotations to use `Optional` for clarity.
- Improved method names and attribute consistency.
- Removed unnecessary variables for better code readability.

Additional Notes:

- Highlighted the use of `Optional` for clearer intent.
- Ensured backward and forward compatibility.

* Restore BpeVocab and SentencePieceVocab classes

- Restored the BpeVocab class for handling BPE tokenization.
- Restored the SentencePieceVocab class for SentencePiece tokenization.

These classes are essential for maintaining the original behavior of the codebase.

* refactor: Standardize vocabulary handling with HfVocab

- Replaced VocabLoader with HfVocab, aligning vocabulary handling across classes.
- Updated initialization of HfVocab with local_files_only=True for AutoTokenizer.
- Introduced optional parameter fname_added_tokens for flexible added token management.
- Streamlined added token handling for clarity and conciseness.
- Maintained special tokens and IDs, enhancing token management.
- Simplified token processing methods for improved readability.
- Added a placeholder for score computation with a default value of -1000.0.
- Optimized newline token check for efficiency.
- Updated __repr__ function for clarity in representation.
- Adjusted type alias Vocab to include BpeVocab, SentencePieceVocab, and HfVocab.
- Removed redundant code related to special token handling, reverse vocabulary mapping, and vocabulary file detection.

This refactoring promotes a standardized and modular approach to vocabulary management, facilitating future integration with a VocabFactory and improving code maintainability and scalability.

* refactor: Enhance readability, functionality, and code quality

- Improved code formatting and readability for better maintainability.
- Refactored LazyUnpickler's CLASSES dictionary for clarity.
- Added print statements and warnings in check_vocab_size for user feedback.
- Removed find_vocab_file_path, as it's superseded by VocabFactory.
- Preparatory changes for upcoming classes: OutputFile and VocabFactory.
- Overall focus on code quality, error handling, and consistency.

These changes reflect a continuous effort to refine the codebase, ensuring it meets best practices and prepares for future enhancements, such as the VocabFactory.

* refactor: Update OutputFile class for enhanced model vocabulary management

- Restructured the constructor for improved readability.
- Updated `add_meta_arch` method for flexible model name determination.
- Introduced `handle_tokenizer_model` for mapping vocab types to supported tokenizer models.
- Streamlined vocabulary extraction with `extract_vocabulary_from_model`.
- Simplified vocabulary metadata addition using `add_meta_vocab`.
- Refactored `add_tensor_info` for clarity and consistency.
- Improved error handling for better user feedback.

These changes signify the development of a versatile and comprehensive `OutputFile` class, enabling efficient management of model conversion output, metadata, vocabulary, and tensor information.

* feat: Introduce VocabFactory for flexible vocabulary management in model conversion

- The VocabFactory class is added to facilitate modular vocabulary handling.
- The constructor initializes a directory path and detects vocabulary-related files.
- The _select_file method provides file paths based on vocabulary type (e.g., BPE, SentencePiece).
- _create_special_vocab generates special vocabularies, accommodating different types.
- The load_vocab method loads vocabularies, handling BPE, SentencePiece, and Hugging Face Fast Tokenizer.
- Error handling and logging enhance debugging and user feedback.
- The modular and flexible design simplifies vocabulary management and supports future extensions.

The VocabFactory class enhances code modularity and maintainability, allowing versatile vocabulary handling in the model conversion process.

* refactor: Improve code organization, argument parsing, and user interface

- Renamed 'default_outfile' to 'default_output_file' for clarity.
- Refactored argument parser setup into 'get_argument_parser' function.
- Introduced descriptive comments for each argument in the parser.
- Added '--vocab-type' argument with choices ["spm", "bpe", "hfft"] for vocabulary processing.
- Improved flag naming consistency: '--outfile' to '--out-file' and '--bigendian' to '--big-endian'.
- Enhanced error handling to prevent overwriting input data in 'default_output_file'.
- Made 'argv' in 'main' an optional parameter for flexibility.
- Introduced dynamic import for 'awq.apply_awq' based on 'args.awq_path' for conditional dependency.

These changes enhance code clarity, organization, and the user interface of the script, aligning it with Python best practices and improving maintainability.

* refactor: Further refine functionality, improve user interaction, and streamline vocabulary handling

- Renamed command-line arguments for clarity and consistency.
- Improved path resolution and import adjustments for robustness.
- Thoughtfully handled 'awq-path' and conditional logic for the weighted model.
- Enhanced model and vocabulary loading with the 'VocabFactory' class for structured and adaptable loading.
- Strengthened error handling and user feedback for a more user-friendly experience.
- Structured output file handling with clear conditions and defaults.
- Streamlined and organized the 'main' function for better logic flow.
- Passed 'sys.argv[1:]' to 'main' for adaptability and testability.

These changes solidify the script's functionality, making it more robust, user-friendly, and adaptable. The use of the 'VocabFactory' class is a notable enhancement in efficient vocabulary handling, reflecting a thoughtful and iterative approach to script development.

* chore: Apply ruff formatting to convert.py

Signed-off-by: teleprint-me <redacted>
* Revert to commit 0614c33

* chore: Apply flake8 formatting rules

Signed-off-by: teleprint-me <redacted>
* refactor: Revise `check_vocab_size` for Enhanced Clarity and Correctness

- Resolved an unreachable branch issue by reorganizing the conditional structure.
- Moved the special case check for `params.n_vocab == -1` to the top for immediate assertion.
- Flattened the conditional logic for improved clarity and predictability of the function's behavior.

These changes enhance the readability and functional correctness of the `check_vocab_size` function without altering its intended functionality.

* py : fix outfile and outtype

* py : suggest hint for missing vocab size

---------

Signed-off-by: teleprint-me <redacted>
Co-authored-by: Georgi Gerganov <redacted>
convert.py

index c3f3fc0a1fcd39ba1076574ea03192b17af125aa..3b613eefc6c2c04dd514ec74e880a582421ad86e 100755 (executable)
@@ -17,29 +17,58 @@ import signal
 import struct
 import sys
 import time
+import warnings
 import zipfile
 from abc import ABCMeta, abstractmethod
-from collections import OrderedDict
+from argparse import ArgumentParser
 from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
 from dataclasses import dataclass
 from pathlib import Path
-from typing import IO, TYPE_CHECKING, Any, Callable, Iterable, Literal, Optional, TypeVar, cast
+from typing import (
+    IO,
+    TYPE_CHECKING,
+    Any,
+    Callable,
+    Iterable,
+    Literal,
+    Optional,
+    Tuple,
+    TypeVar,
+)
 
 import numpy as np
 from sentencepiece import SentencePieceProcessor
 
-if 'NO_LOCAL_GGUF' not in os.environ:
-    sys.path.insert(1, str(Path(__file__).parent / 'gguf-py'))
-import gguf
-
-if TYPE_CHECKING:
-    from typing import TypeAlias
-
-if hasattr(faulthandler, 'register') and hasattr(signal, 'SIGUSR1'):
+try:
+    from transformers import AutoTokenizer
+except ModuleNotFoundError as e:
+    warnings.warn(f"Could not import AutoTokenizer from transformers: {e}")
+
+# If NO_LOCAL_GGUF is not set, try to import gguf from the local gguf-py directory
+if "NO_LOCAL_GGUF" not in os.environ:
+    # Use absolute path to the gguf-py directory
+    gguf_py_dir = str(Path(__file__).resolve().parent / "gguf-py")
+    print(gguf_py_dir)  # NOTE: Remove this once path is verified after changes are completed
+    if gguf_py_dir not in sys.path:
+        sys.path.insert(1, gguf_py_dir)
+
+# Import gguf module
+try:
+    import gguf
+except ModuleNotFoundError as e:
+    print(f"Could not import gguf: {e}")
+    sys.exit(1)
+
+if TYPE_CHECKING:  # NOTE: This isn't necessary.
+    from typing import TypeAlias  # This can technically be omitted.
+
+if hasattr(faulthandler, "register") and hasattr(signal, "SIGUSR1"):
     faulthandler.register(signal.SIGUSR1)
 
-NDArray: TypeAlias = 'np.ndarray[Any, Any]'
+# NOTE: n-dimensional arrays should be directly referenced
+NDArray: TypeAlias = "np.ndarray[Any, Any]"
 
+# Why is this here? LLAMA and GPT are technically the only compatible ARCHs.
 ARCH = gguf.MODEL_ARCH.LLAMA
 
 DEFAULT_CONCURRENCY = 8
@@ -49,6 +78,7 @@ DEFAULT_CONCURRENCY = 8
 #
 
 
+# TODO: Clean up and refactor data types
 @dataclass(frozen=True)
 class DataType:
     name: str
@@ -153,65 +183,85 @@ GGML_FILE_TYPE_TO_DATA_TYPE: dict[GGMLFileType, DataType] = {
 
 @dataclass
 class Params:
-    n_vocab:        int
-    n_embd:         int
-    n_layer:        int
-    n_ctx:          int
-    n_ff:           int
-    n_head:         int
-    n_head_kv:      int
-    n_experts:      int | None = None
-    n_experts_used: int | None = None
-    f_norm_eps:     float | None = None
-
-    rope_scaling_type: gguf.RopeScalingType | None = None
-    f_rope_freq_base: float | None = None
-    f_rope_scale: float | None = None
-    n_orig_ctx: int | None = None
-    rope_finetuned: bool | None = None
-
-    ftype: GGMLFileType | None = None
+    n_vocab: int
+    n_embd: int
+    n_layer: int
+    n_ctx: int
+    n_ff: int
+    n_head: int
+    n_head_kv: int
+    f_norm_eps: Optional[float] = None
+    n_experts: Optional[int] = None
+    n_experts_used: Optional[int] = None
+
+    rope_scaling_type: Optional[gguf.RopeScalingType] = None
+    f_rope_freq_base: Optional[float] = None
+    f_rope_scale: Optional[float] = None
+    n_orig_ctx: Optional[int] = None
+    rope_finetuned: Optional[bool] = None
+
+    ftype: Optional[GGMLFileType] = None
 
     # path to the directory containing the model files
-    path_model: Path | None = None
+    path_model: Optional[Path] = None
 
     @staticmethod
-    def guessed(model: LazyModel) -> Params:
+    def guessed(model: LazyModel) -> "Params":
         # try transformer naming first
-        n_vocab, n_embd = model["model.embed_tokens.weight"].shape if "model.embed_tokens.weight" in model else model["tok_embeddings.weight"].shape
+        n_vocab, n_embd = (
+            model["model.embed_tokens.weight"].shape
+            if "model.embed_tokens.weight" in model
+            else model["tok_embeddings.weight"].shape
+        )
 
         # try transformer naming first
         if "model.layers.0.self_attn.q_proj.weight" in model:
-            n_layer = next(i for i in itertools.count() if f"model.layers.{i}.self_attn.q_proj.weight" not in model)
-        elif "model.layers.0.self_attn.W_pack.weight" in model:   # next: try baichuan naming
-            n_layer = next(i for i in itertools.count() if f"model.layers.{i}.self_attn.W_pack.weight" not in model)
+            n_layer = next(
+                i
+                for i in itertools.count()
+                if f"model.layers.{i}.self_attn.q_proj.weight" not in model
+            )
+        elif (
+            "model.layers.0.self_attn.W_pack.weight" in model
+        ):  # next: try baichuan naming
+            n_layer = next(
+                i
+                for i in itertools.count()
+                if f"model.layers.{i}.self_attn.W_pack.weight" not in model
+            )
         else:
-            n_layer = next(i for i in itertools.count() if f"layers.{i}.attention.wq.weight" not in model)
+            n_layer = next(
+                i
+                for i in itertools.count()
+                if f"layers.{i}.attention.wq.weight" not in model
+            )
 
         if n_layer < 1:
-            raise Exception("failed to guess 'n_layer'. This model is unknown or unsupported.\n"
-                            "Suggestion: provide 'config.json' of the model in the same directory containing model files.")
+            raise Exception(
+                "failed to guess 'n_layer'. This model is unknown or unsupported.\n"
+                "Suggestion: provide 'config.json' of the model in the same directory containing model files."
+            )
 
-        n_head = n_embd // 128 # guessed
-        n_mult = 256           # guessed
+        n_head = n_embd // 128  # guessed
+        n_mult = 256  # guessed
 
         # TODO: verify this
         n_ff = int(2 * (4 * n_embd) / 3)
         n_ff = n_mult * ((n_ff + n_mult - 1) // n_mult)
 
         return Params(
-            n_vocab    = n_vocab,
-            n_embd     = n_embd,
-            n_layer    = n_layer,
-            n_ctx      = -1,
-            n_ff       = n_ff,
-            n_head     = n_head,
-            n_head_kv  = n_head,
-            f_norm_eps = 1e-5,
+            n_vocab=n_vocab,
+            n_embd=n_embd,
+            n_layer=n_layer,
+            n_ctx=-1,
+            n_ff=n_ff,
+            n_head=n_head,
+            n_head_kv=n_head,
+            f_norm_eps=1e-5,
         )
 
     @staticmethod
-    def loadHFTransformerJson(model: LazyModel, config_path: Path) -> Params:
+    def load_transformers_config(model: LazyModel, config_path: Path) -> "Params":
         config = json.load(open(config_path))
 
         rope_scaling_type = f_rope_scale = n_orig_ctx = rope_finetuned = None
@@ -224,20 +274,22 @@ class Params:
                 rope_scaling_type = gguf.RopeScalingType.LINEAR
             elif typ == "yarn":
                 rope_scaling_type = gguf.RopeScalingType.YARN
-                n_orig_ctx = rope_scaling['original_max_position_embeddings']
-                rope_finetuned = rope_scaling['finetuned']
+                n_orig_ctx = rope_scaling["original_max_position_embeddings"]
+                rope_finetuned = rope_scaling["finetuned"]
             else:
-                raise NotImplementedError(f'Unknown rope scaling type: {typ}')
+                raise NotImplementedError(f"Unknown rope scaling type: {typ}")
 
         if "max_sequence_length" in config:
             n_ctx = config["max_sequence_length"]
         elif "max_position_embeddings" in config:
             n_ctx = config["max_position_embeddings"]
         else:
-            raise Exception("failed to guess 'n_ctx'. This model is unknown or unsupported.\n"
-                            "Suggestion: provide 'config.json' of the model in the same directory containing model files.")
+            raise Exception(
+                "failed to guess 'n_ctx'. This model is unknown or unsupported.\n"
+                "Suggestion: provide 'config.json' of the model in the same directory containing model files."
+            )
 
-        n_experts      = None
+        n_experts = None
         n_experts_used = None
 
         if "num_local_experts" in config:
@@ -245,30 +297,30 @@ class Params:
             n_experts_used = config["num_experts_per_tok"]
 
         return Params(
-            n_vocab           = config["vocab_size"],
-            n_embd            = config["hidden_size"],
-            n_layer           = config["num_hidden_layers"],
-            n_ctx             = n_ctx,
-            n_ff              = config["intermediate_size"],
-            n_head            = (n_head := config["num_attention_heads"]),
-            n_head_kv         = config.get("num_key_value_heads", n_head),
-            n_experts         = n_experts,
-            n_experts_used    = n_experts_used,
-            f_norm_eps        = config["rms_norm_eps"],
-            f_rope_freq_base  = config.get("rope_theta"),
-            rope_scaling_type = rope_scaling_type,
-            f_rope_scale      = f_rope_scale,
-            n_orig_ctx        = n_orig_ctx,
-            rope_finetuned    = rope_finetuned,
+            n_vocab=config["vocab_size"],
+            n_embd=config["hidden_size"],
+            n_layer=config["num_hidden_layers"],
+            n_ctx=n_ctx,
+            n_ff=config["intermediate_size"],
+            n_head=(n_head := config["num_attention_heads"]),
+            n_head_kv=config.get("num_key_value_heads", n_head),
+            n_experts=n_experts,
+            n_experts_used=n_experts_used,
+            f_norm_eps=config["rms_norm_eps"],
+            f_rope_freq_base=config.get("rope_theta"),
+            rope_scaling_type=rope_scaling_type,
+            f_rope_scale=f_rope_scale,
+            n_orig_ctx=n_orig_ctx,
+            rope_finetuned=rope_finetuned,
         )
 
     # LLaMA v2 70B params.json
     # {"dim": 8192, "multiple_of": 4096, "ffn_dim_multiplier": 1.3, "n_heads": 64, "n_kv_heads": 8, "n_layers": 80, "norm_eps": 1e-05, "vocab_size": -1}
     @staticmethod
-    def loadOriginalParamsJson(model: LazyModel, config_path: Path) -> Params:
+    def load_torch_params(model: LazyModel, config_path: Path) -> "Params":
         config = json.load(open(config_path))
 
-        n_experts      = None
+        n_experts = None
         n_experts_used = None
         f_rope_freq_base = None
 
@@ -291,129 +343,249 @@ class Params:
 
         if config.get("moe"):
             n_ff = model["layers.0.feed_forward.experts.0.w1.weight"].shape[0]
-            n_experts      = config["moe"]["num_experts"]
+            n_experts = config["moe"]["num_experts"]
             n_experts_used = config["moe"]["num_experts_per_tok"]
             f_rope_freq_base = 1e6
 
         return Params(
-            n_vocab          = model["tok_embeddings.weight"].shape[0],
-            n_embd           = config["dim"],
-            n_layer          = config["n_layers"],
-            n_ctx            = n_ctx,
-            n_ff             = n_ff,
-            n_head           = (n_head := config["n_heads"]),
-            n_head_kv        = config.get("n_kv_heads", n_head),
-            n_experts        = n_experts,
-            n_experts_used   = n_experts_used,
-            f_norm_eps       = config["norm_eps"],
-            f_rope_freq_base = config.get("rope_theta", f_rope_freq_base),
+            n_vocab=config.get("vocab_size", model["tok_embeddings.weight"].shape[0]),
+            n_embd=config["dim"],
+            n_layer=config["n_layers"],
+            n_ctx=n_ctx,
+            n_ff=n_ff,
+            n_head=(n_head := config["n_heads"]),
+            n_head_kv=config.get("n_kv_heads", n_head),
+            n_experts=n_experts,
+            n_experts_used=n_experts_used,
+            f_norm_eps=config["norm_eps"],
+            f_rope_freq_base=config.get("rope_theta", f_rope_freq_base),
         )
 
     @staticmethod
-    def load(model_plus: ModelPlus) -> Params:
-        hf_config_path   = model_plus.paths[0].parent / "config.json"
+    def load(model_plus: ModelPlus) -> "Params":
+        hf_config_path = model_plus.paths[0].parent / "config.json"
         orig_config_path = model_plus.paths[0].parent / "params.json"
 
         if hf_config_path.exists():
-            params = Params.loadHFTransformerJson(model_plus.model, hf_config_path)
+            params = Params.load_transformers_config(model_plus.model, hf_config_path)
         elif orig_config_path.exists():
-            params = Params.loadOriginalParamsJson(model_plus.model, orig_config_path)
-        elif model_plus.format != 'none':
+            params = Params.load_torch_params(model_plus.model, orig_config_path)
+        elif model_plus.format != "none":
             params = Params.guessed(model_plus.model)
         else:
-            raise ValueError('Cannot guess params when model format is none')
+            raise ValueError("Cannot guess params when model format is none")
 
         params.path_model = model_plus.paths[0].parent
 
         return params
 
 
-class VocabLoader:
-    def __init__(self, params: Params, fname_tokenizer: Path) -> None:
-        try:
-            from transformers import AutoTokenizer
-        except ImportError as e:
-            raise ImportError(
-                "To use VocabLoader, please install the `transformers` package. "
-                "You can install it with `pip install transformers`."
-            ) from e
+class BpeVocab:  # GPT
+    def __init__(
+        self, fname_tokenizer: Path, fname_added_tokens: Optional[Path]
+    ) -> None:
+        self.bpe_tokenizer = json.loads(
+            open(str(fname_tokenizer), encoding="utf-8").read()
+        )
+        added_tokens: dict[str, int]
+        if fname_added_tokens is not None:
+            # FIXME: Verify that added tokens here _cannot_ overlap with the main vocab.
+            added_tokens = json.load(open(fname_added_tokens, encoding="utf-8"))
+        else:
+            # Fall back to trying to find the added tokens in tokenizer.json
+            tokenizer_json_file = fname_tokenizer.parent / "tokenizer.json"
+            if not tokenizer_json_file.is_file():
+                added_tokens = {}
+            else:
+                tokenizer_json = json.load(open(tokenizer_json_file, encoding="utf-8"))
+                added_tokens = dict(
+                    (item["content"], item["id"])
+                    for item in tokenizer_json.get("added_tokens", [])
+                    # Added tokens here can be duplicates of the main vocabulary.
+                    if item["content"] not in self.bpe_tokenizer
+                )
+
+        vocab_size: int = len(self.bpe_tokenizer)
+        expected_ids = list(range(vocab_size, vocab_size + len(added_tokens)))
+        actual_ids = sorted(added_tokens.values())
+        if expected_ids != actual_ids:
+            expected_end_id = vocab_size + len(actual_ids) - 1
+            raise Exception(
+                f"Expected the {len(actual_ids)} added token ID(s) to be sequential in the range {vocab_size} - {expected_end_id}; got {actual_ids}"
+            )
+
+        items = sorted(added_tokens.items(), key=lambda text_idx: text_idx[1])
+        self.added_tokens_list = [text for (text, idx) in items]
+        self.vocab_size_base: int = vocab_size
+        self.vocab_size: int = self.vocab_size_base + len(self.added_tokens_list)
+        self.fname_tokenizer = fname_tokenizer
+        self.fname_added_tokens = fname_added_tokens
+
+    def bpe_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
+        tokenizer = self.bpe_tokenizer
+        reverse_vocab = {id: encoded_tok for encoded_tok, id in tokenizer.items()}
+
+        for i, _ in enumerate(tokenizer):
+            yield reverse_vocab[i], 0.0, gguf.TokenType.NORMAL
 
-        try:
-            self.tokenizer = AutoTokenizer.from_pretrained(str(fname_tokenizer), trust_remote_code=True)
-        except ValueError:
-            self.tokenizer = AutoTokenizer.from_pretrained(str(fname_tokenizer), use_fast=False, trust_remote_code=True)
+    def added_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
+        for text in self.added_tokens_list:
+            score = -1000.0
+            yield text.encode("utf-8"), score, gguf.TokenType.CONTROL
+
+    def all_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
+        yield from self.bpe_tokens()
+        yield from self.added_tokens()
 
-        self.added_tokens_dict: OrderedDict[str, int] = OrderedDict()
+    def __repr__(self) -> str:
+        return f"<BpeVocab with {self.vocab_size_base} base tokens and {len(self.added_tokens_list)} added tokens>"
 
-        for tok, tokidx in sorted(self.tokenizer.get_added_vocab().items(), key=lambda x: x[1]):
-            if tokidx >= params.n_vocab or tokidx < self.tokenizer.vocab_size:
-                continue
 
-            self.added_tokens_dict[tok] = tokidx
+class SentencePieceVocab:  # LlaMa
+    def __init__(
+        self, fname_tokenizer: Path, fname_added_tokens: Optional[Path]
+    ) -> None:
+        self.sentencepiece_tokenizer = SentencePieceProcessor(str(fname_tokenizer))
+        added_tokens: dict[str, int]
+        if fname_added_tokens is not None:
+            added_tokens = json.load(open(fname_added_tokens, encoding="utf-8"))
+        else:
+            added_tokens = {}
+
+        vocab_size: int = self.sentencepiece_tokenizer.vocab_size()
+
+        new_tokens = {
+            id: piece for piece, id in added_tokens.items() if id >= vocab_size
+        }
+        expected_new_ids = list(range(vocab_size, vocab_size + len(new_tokens)))
+        actual_new_ids = sorted(new_tokens.keys())
+
+        if expected_new_ids != actual_new_ids:
+            raise ValueError(
+                f"Expected new token IDs {expected_new_ids} to be sequential; got {actual_new_ids}"
+            )
+
+        # Token pieces that were added to the base vocabulary.
+        self.added_tokens_list = [new_tokens[id] for id in actual_new_ids]
+        self.vocab_size_base = vocab_size
+        self.vocab_size = self.vocab_size_base + len(self.added_tokens_list)
+        self.fname_tokenizer = fname_tokenizer
+        self.fname_added_tokens = fname_added_tokens
+
+    def sentencepiece_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
+        tokenizer = self.sentencepiece_tokenizer
+        for i in range(tokenizer.vocab_size()):
+            piece = tokenizer.id_to_piece(i)
+            text: bytes = piece.encode("utf-8")
+            score: float = tokenizer.get_score(i)
+
+            toktype = gguf.TokenType.NORMAL
+            if tokenizer.is_unknown(i):
+                toktype = gguf.TokenType.UNKNOWN
+            if tokenizer.is_control(i):
+                toktype = gguf.TokenType.CONTROL
+
+            # NOTE: I think added_tokens are user defined.
+            # ref: https://github.com/google/sentencepiece/blob/master/src/sentencepiece_model.proto
+            # if tokenizer.is_user_defined(i): toktype = gguf.TokenType.USER_DEFINED
+
+            if tokenizer.is_unused(i):
+                toktype = gguf.TokenType.UNUSED
+            if tokenizer.is_byte(i):
+                toktype = gguf.TokenType.BYTE
+
+            yield text, score, toktype
 
-        self.unk_token_id: int = self.tokenizer.unk_token_id
-        self.specials: dict[str, int] = {
+    def added_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
+        for text in self.added_tokens_list:
+            score = -1000.0
+            yield text.encode("utf-8"), score, gguf.TokenType.USER_DEFINED
+
+    def all_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
+        yield from self.sentencepiece_tokens()
+        yield from self.added_tokens()
+
+    def __repr__(self) -> str:
+        return f"<SentencePieceVocab with {self.vocab_size_base} base tokens and {len(self.added_tokens_list)} added tokens>"
+
+
+class HfVocab:
+    def __init__(
+        self,
+        fname_tokenizer: Path,
+        fname_added_tokens: Optional[Path] = None,
+    ) -> None:
+        print("fname_tokenizer:", fname_tokenizer)
+        # Allow the tokenizer to default to slow or fast versions.
+        # Explicitly set tokenizer to use local paths.
+        self.tokenizer = AutoTokenizer.from_pretrained(
+            fname_tokenizer,
+            cache_dir=fname_tokenizer,
+            local_files_only=True,
+        )
+
+        # Initialize lists and dictionaries for added tokens
+        self.added_tokens_list = []
+        self.added_tokens_dict = dict()
+        self.added_tokens_ids = set()
+
+        # Process added tokens
+        for tok, tokidx in sorted(
+            self.tokenizer.get_added_vocab().items(), key=lambda x: x[1]
+        ):
+            # Only consider added tokens that are not in the base vocabulary
+            if tokidx >= self.tokenizer.vocab_size:
+                self.added_tokens_list.append(tok)
+                self.added_tokens_dict[tok] = tokidx
+                self.added_tokens_ids.add(tokidx)
+
+        # Store special tokens and their IDs
+        self.specials = {
             tok: self.tokenizer.get_vocab()[tok]
             for tok in self.tokenizer.all_special_tokens
         }
-        self.special_ids: set[int] = set(self.tokenizer.all_special_ids)
-        self.reverse_vocab = {id: encoded_tok for encoded_tok, id in self.tokenizer.get_vocab().items()}
-        self.vocab_size_base: int = self.tokenizer.vocab_size
-        self.vocab_size: int = self.vocab_size_base + len(self.added_tokens_dict)
-        self.fname_tokenizer: Path = fname_tokenizer
-
-        vocab_file = "tokenizer.model"
-        path_candidate = find_vocab_file_path(self.fname_tokenizer, vocab_file)
-        if path_candidate is not None:
-            self.spm = SentencePieceProcessor(str(path_candidate))
-            print(self.spm.vocab_size(), self.vocab_size_base)
-        else:
-            self.spm = None
+        self.special_ids = set(self.tokenizer.all_special_ids)
 
-    def hf_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
-        added_tokens_ids = set(self.added_tokens_dict.values())
+        # Set vocabulary sizes
+        self.vocab_size_base = self.tokenizer.vocab_size
+        self.vocab_size = self.vocab_size_base + len(self.added_tokens_list)
 
-        for i in range(self.vocab_size_base):
-            if i in added_tokens_ids:
-                continue
+        self.fname_tokenizer = fname_tokenizer
+        self.fname_added_tokens = fname_added_tokens
 
-            text = self.reverse_vocab[i].encode("utf-8")
-            yield text, self.get_token_score(i), self.get_token_type(i)
+    def hf_tokens(self) -> Iterable[Tuple[bytes, float, gguf.TokenType]]:
+        reverse_vocab = {
+            id: encoded_tok for encoded_tok, id in self.tokenizer.get_vocab().items()
+        }
 
-    def get_token_type(self, token_id: int) -> gguf.TokenType:
-        toktype = gguf.TokenType.NORMAL
+        for token_id in range(self.vocab_size_base):
+            # Skip processing added tokens here
+            if token_id in self.added_tokens_ids:
+                continue
 
-        if self.spm is not None and token_id < self.spm.vocab_size():
-            if self.spm.is_unknown(token_id):
-                toktype = gguf.TokenType.UNKNOWN
-            if self.spm.is_control(token_id):
-                toktype = gguf.TokenType.CONTROL
-            if self.spm.is_unused(token_id):
-                toktype = gguf.TokenType.UNUSED
-            if self.spm.is_byte(token_id):
-                toktype = gguf.TokenType.BYTE
-        else:
-            token = self.reverse_vocab[token_id]
-            if token_id == self.unk_token_id:
-                toktype = gguf.TokenType.UNKNOWN
-            elif token_id in self.special_ids:
-                toktype = gguf.TokenType.CONTROL
-            elif len(token) == 6 and token.startswith("<0x") and token.endswith(">"):
-                toktype = gguf.TokenType.BYTE
+            # Convert token text to bytes
+            token_text = reverse_vocab[token_id].encode("utf-8")
+
+            # Yield token text, score, and type
+            yield token_text, self.get_token_score(token_id), self.get_token_type(
+                token_id, self.special_ids  # Reuse already stored special IDs
+            )
 
-        return toktype
+    def get_token_type(self, token_id: int, special_ids: set) -> gguf.TokenType:
+        # Determine token type based on whether it's a special token
+        return (
+            gguf.TokenType.CONTROL if token_id in special_ids else gguf.TokenType.NORMAL
+        )
 
     def get_token_score(self, token_id: int) -> float:
-        if self.spm is not None and token_id < self.spm.vocab_size():
-            return cast(float, self.spm.get_score(token_id))
-        return 0.0
+        # Placeholder for actual logic to determine the token's score
+        # This needs to be implemented based on specific requirements
+        return -1000.0  # Default score
 
     def added_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
-
-        for text in self.added_tokens_dict:
+        for text in self.added_tokens_list:
             if text in self.specials:
-
-                toktype = self.get_token_type(self.specials[text])
+                toktype = self.get_token_type(self.specials[text], self.special_ids)
                 score = self.get_token_score(self.specials[text])
 
             else:
@@ -422,45 +594,18 @@ class VocabLoader:
 
             yield text.encode("utf-8"), score, toktype
 
-    def has_newline_token(self) -> bool:
-        return '<0x0A>' in self.tokenizer.vocab or '\n' in self.tokenizer.vocab
+    def has_newline_token(self):
+        return "<0x0A>" in self.tokenizer.vocab or "\n" in self.tokenizer.vocab
 
     def all_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
         yield from self.hf_tokens()
         yield from self.added_tokens()
 
-    def get_vocab_type(self) -> str:
-        path_candidates = []
-        vocab_file = "tokenizer.model"
-        path_candidates.append(vocab_file)
-        path_candidate = find_vocab_file_path(self.fname_tokenizer, vocab_file)
-        if path_candidate is not None:
-            return "llama"
-
-        vocab_file = "vocab.json"
-        path_candidates.append(vocab_file)
-        path_candidate = find_vocab_file_path(self.fname_tokenizer, vocab_file)
-        if path_candidate is not None:
-            return "gpt2"
-
-        vocab_file = "tokenizer.json"
-        path_candidates.append(vocab_file)
-        path_candidate = find_vocab_file_path(self.fname_tokenizer, vocab_file)
-        if path_candidate:
-            if not self.has_newline_token():
-                return "gpt2"
-            return "llama"
-
-        raise FileNotFoundError(
-            f"Could not find {path_candidates} in {self.fname_tokenizer} or its parent; "
-            "if it's in another directory, pass the directory as --vocab-dir"
-        )
-
     def __repr__(self) -> str:
-        return f"<VocabLoader with {self.vocab_size_base} base tokens and {len(self.added_tokens_dict)} added tokens>"
+        return f"<HfVocab with {self.vocab_size_base} base tokens and {len(self.added_tokens_list)} added tokens>"
 
 
-Vocab: TypeAlias = 'VocabLoader'
+Vocab: TypeAlias = "BpeVocab | SentencePieceVocab | HfVocab"
 
 
 #
@@ -724,13 +869,17 @@ class LazyUnpickler(pickle.Unpickler):
     CLASSES: dict[tuple[str, str], Any] = {
         # getattr used here as a workaround for mypy not being smart enough to determine
         # the staticmethods have a __func__ attribute.
-        ('torch._tensor', '_rebuild_from_type_v2'): getattr(rebuild_from_type_v2, '__func__'),
-        ('torch._utils', '_rebuild_tensor_v2'): getattr(lazy_rebuild_tensor_v2, '__func__'),
-        ('torch', 'BFloat16Storage'): LazyStorageKind(DT_BF16),
-        ('torch', 'HalfStorage'): LazyStorageKind(DT_F16),
-        ('torch', 'FloatStorage'): LazyStorageKind(DT_F32),
-        ('torch', 'IntStorage'): LazyStorageKind(DT_I32),
-        ('torch', 'Tensor'): LazyTensor,
+        ("torch._tensor", "_rebuild_from_type_v2"): getattr(
+            rebuild_from_type_v2, "__func__"
+        ),
+        ("torch._utils", "_rebuild_tensor_v2"): getattr(
+            lazy_rebuild_tensor_v2, "__func__"
+        ),
+        ("torch", "BFloat16Storage"): LazyStorageKind(DT_BF16),
+        ("torch", "HalfStorage"): LazyStorageKind(DT_F16),
+        ("torch", "FloatStorage"): LazyStorageKind(DT_F32),
+        ("torch", "IntStorage"): LazyStorageKind(DT_I32),
+        ("torch", "Tensor"): LazyTensor,
     }
 
     def find_class(self, module: str, name: str) -> Any:
@@ -839,32 +988,43 @@ def bounded_parallel_map(func: Callable[[In], Out], iterable: Iterable[In], conc
 
 
 def check_vocab_size(params: Params, vocab: Vocab, pad_vocab: bool = False) -> None:
-    if params.n_vocab != vocab.vocab_size:
-        if params.n_vocab == vocab.vocab_size:
-            print("Ignoring added_tokens.json since model matches vocab size without it.")
-            vocab.added_tokens_dict = OrderedDict()
-            vocab.vocab_size = vocab.vocab_size
-            return
-
-        if pad_vocab and params.n_vocab > vocab.vocab_size:
-            pad_count = params.n_vocab - vocab.vocab_size
-            print(f'Padding vocab with {pad_count} token(s) - <dummy00001> through <dummy{pad_count:05}>')
-            for i in range(1, (params.n_vocab - vocab.vocab_size) + 1):
-                vocab.added_tokens_dict[f'<dummy{i:05}>'] = -1
-            vocab.vocab_size = params.n_vocab
-            return
-        msg = f"Vocab size mismatch (model has {params.n_vocab}, but {vocab.fname_tokenizer}"
-        msg += f" has {vocab.vocab_size})."
-        if vocab.vocab_size < params.n_vocab < vocab.vocab_size + 20:
-            msg += f"  Most likely you are missing added_tokens.json (should be in {vocab.fname_tokenizer.parent})."
-        if vocab.vocab_size < params.n_vocab:
-            msg += " Possibly try using the --padvocab option."
-        raise Exception(msg)
+    # Handle special case where the model's vocab size is not set
+    if params.n_vocab == -1:
+        raise ValueError(
+            f"The model's vocab size is set to -1 in params.json. Please update it manually. Maybe {vocab.vocab_size}?"
+        )
+
+    # Check for a vocab size mismatch
+    if params.n_vocab == vocab.vocab_size:
+        print("Ignoring added_tokens.json since model matches vocab size without it.")
+        return
+
+    if pad_vocab and params.n_vocab > vocab.vocab_size:
+        pad_count = params.n_vocab - vocab.vocab_size
+        print(
+            f"Padding vocab with {pad_count} token(s) - <dummy00001> through <dummy{pad_count:05}>"
+        )
+        for i in range(1, pad_count + 1):
+            vocab.added_tokens_dict[f"<dummy{i:05}>"] = -1
+        vocab.vocab_size = params.n_vocab
+        return
+
+    msg = f"Vocab size mismatch (model has {params.n_vocab}, but {vocab.fname_tokenizer} has {vocab.vocab_size})."
+    if vocab.vocab_size < params.n_vocab < vocab.vocab_size + 20:
+        msg += f"  Most likely you are missing added_tokens.json (should be in {vocab.fname_tokenizer.parent})."
+    if vocab.vocab_size < params.n_vocab:
+        msg += " Add the --pad-vocab option and try again."
+
+    raise Exception(msg)
 
 
 class OutputFile:
-    def __init__(self, fname_out: Path, endianess:gguf.GGUFEndian = gguf.GGUFEndian.LITTLE) -> None:
-        self.gguf = gguf.GGUFWriter(fname_out, gguf.MODEL_ARCH_NAMES[ARCH], endianess=endianess)
+    def __init__(
+        self, fname_out: Path, endianess: gguf.GGUFEndian = gguf.GGUFEndian.LITTLE
+    ) -> None:
+        self.gguf = gguf.GGUFWriter(
+            fname_out, gguf.MODEL_ARCH_NAMES[ARCH], endianess=endianess
+        )
 
     def add_meta_arch(self, params: Params) -> None:
         name = "LLaMA"
@@ -873,16 +1033,21 @@ class OutputFile:
         if params.n_ctx == 4096:
             name = "LLaMA v2"
         elif params.path_model is not None:
-            name = str(params.path_model.parent).split('/')[-1]
+            name = str(params.path_model.parent).split("/")[-1]
 
-        self.gguf.add_name                (name)
-        self.gguf.add_context_length      (params.n_ctx)
-        self.gguf.add_embedding_length    (params.n_embd)
-        self.gguf.add_block_count         (params.n_layer)
-        self.gguf.add_feed_forward_length (params.n_ff)
+        self.gguf.add_name(name)
+        self.gguf.add_context_length(params.n_ctx)
+        self.gguf.add_embedding_length(params.n_embd)
+        self.gguf.add_block_count(params.n_layer)
+        self.gguf.add_feed_forward_length(params.n_ff)
         self.gguf.add_rope_dimension_count(params.n_embd // params.n_head)
-        self.gguf.add_head_count          (params.n_head)
-        self.gguf.add_head_count_kv       (params.n_head_kv)
+        self.gguf.add_head_count(params.n_head)
+        self.gguf.add_head_count_kv(params.n_head_kv)
+
+        if params.f_norm_eps is None:
+            raise ValueError("f_norm_eps is None")
+
+        self.gguf.add_layer_norm_rms_eps(params.f_norm_eps)
 
         if params.n_experts:
             self.gguf.add_expert_count(params.n_experts)
@@ -890,11 +1055,6 @@ class OutputFile:
         if params.n_experts_used:
             self.gguf.add_expert_used_count(params.n_experts_used)
 
-        if params.f_norm_eps:
-            self.gguf.add_layer_norm_rms_eps(params.f_norm_eps)
-        else:
-            raise ValueError('f_norm_eps is None')
-
         if params.f_rope_freq_base is not None:
             self.gguf.add_rope_freq_base(params.f_rope_freq_base)
 
@@ -912,18 +1072,44 @@ class OutputFile:
         if params.ftype is not None:
             self.gguf.add_file_type(params.ftype)
 
-    def add_meta_vocab(self, vocab: Vocab) -> None:
+    def handle_tokenizer_model(self, vocab: Vocab) -> str:
+        # Map the vocab types to the supported tokenizer models
+        tokenizer_model = {
+            SentencePieceVocab: "llama",
+            HfVocab: "llama",
+            BpeVocab: "gpt2",
+        }.get(type(vocab))
+
+        # Block if vocab type is not predefined
+        if tokenizer_model is None:
+            raise ValueError("Unknown vocab type: Not supported")
+
+        return tokenizer_model
+
+    def extract_vocabulary_from_model(self, vocab: Vocab) -> Tuple[list, list, list]:
         tokens = []
         scores = []
         toktypes = []
+
         # NOTE: `all_tokens` returns the base vocabulary and added tokens
         for text, score, toktype in vocab.all_tokens():
             tokens.append(text)
             scores.append(score)
             toktypes.append(toktype)
 
-        vocab_type = vocab.get_vocab_type()
-        self.gguf.add_tokenizer_model(vocab_type)
+        return tokens, scores, toktypes
+
+    def add_meta_vocab(self, vocab: Vocab) -> None:
+        # Handle the tokenizer model
+        tokenizer_model = self.handle_tokenizer_model(vocab)
+
+        # Ensure that tokenizer_model is added to the GGUF model
+        self.gguf.add_tokenizer_model(tokenizer_model)
+
+        # Extract model vocabulary for model conversion
+        tokens, scores, toktypes = self.extract_vocabulary_from_model(vocab)
+
+        # Add extracted token information for model conversion
         self.gguf.add_token_list(tokens)
         self.gguf.add_token_scores(scores)
         self.gguf.add_token_types(toktypes)
@@ -933,10 +1119,14 @@ class OutputFile:
 
     def add_tensor_info(self, name: str, tensor: LazyTensor) -> None:
         n_elements = int(np.prod(tensor.shape))
-        raw_dtype = getattr(tensor.data_type, 'ggml_type', None)
-        data_type = getattr(tensor.data_type, 'quantized_type', None) or tensor.data_type.dtype
+        raw_dtype = getattr(tensor.data_type, "ggml_type", None)
+        data_type = (
+            getattr(tensor.data_type, "quantized_type", None) or tensor.data_type.dtype
+        )
         data_nbytes = tensor.data_type.elements_to_bytes(n_elements)
-        self.gguf.add_tensor_info(name, tensor.shape, data_type, data_nbytes, raw_dtype = raw_dtype)
+        self.gguf.add_tensor_info(
+            name, tensor.shape, data_type, data_nbytes, raw_dtype=raw_dtype
+        )
 
     def write_meta(self) -> None:
         self.gguf.write_header_to_file()
@@ -950,11 +1140,14 @@ class OutputFile:
 
     @staticmethod
     def write_vocab_only(
-        fname_out: Path, params: Params, vocab: Vocab, svocab: gguf.SpecialVocab,
+        fname_out: Path,
+        params: Params,
+        vocab: Vocab,
+        svocab: gguf.SpecialVocab,
         endianess: gguf.GGUFEndian = gguf.GGUFEndian.LITTLE,
         pad_vocab: bool = False,
     ) -> None:
-        check_vocab_size(params, vocab, pad_vocab = pad_vocab)
+        check_vocab_size(params, vocab, pad_vocab=pad_vocab)
 
         of = OutputFile(fname_out, endianess=endianess)
 
@@ -982,12 +1175,17 @@ class OutputFile:
 
     @staticmethod
     def write_all(
-        fname_out: Path, ftype: GGMLFileType, params: Params, model: LazyModel, vocab: Vocab, svocab: gguf.SpecialVocab,
+        fname_out: Path,
+        ftype: GGMLFileType,
+        params: Params,
+        model: LazyModel,
+        vocab: Vocab,
+        svocab: gguf.SpecialVocab,
         concurrency: int = DEFAULT_CONCURRENCY,
         endianess: gguf.GGUFEndian = gguf.GGUFEndian.LITTLE,
         pad_vocab: bool = False,
     ) -> None:
-        check_vocab_size(params, vocab, pad_vocab = pad_vocab)
+        check_vocab_size(params, vocab, pad_vocab=pad_vocab)
 
         of = OutputFile(fname_out, endianess=endianess)
 
@@ -1004,18 +1202,30 @@ class OutputFile:
         of.write_tensor_info()
 
         # tensor data
-        ndarrays_inner = bounded_parallel_map(OutputFile.do_item, model.items(), concurrency = concurrency)
+        ndarrays_inner = bounded_parallel_map(
+            OutputFile.do_item, model.items(), concurrency=concurrency
+        )
         if ftype == GGMLFileType.MostlyQ8_0:
-            ndarrays = bounded_parallel_map(OutputFile.maybe_do_quantize, ndarrays_inner, concurrency = concurrency, max_workers = concurrency, use_processpool_executor = True)
+            ndarrays = bounded_parallel_map(
+                OutputFile.maybe_do_quantize,
+                ndarrays_inner,
+                concurrency=concurrency,
+                max_workers=concurrency,
+                use_processpool_executor=True,
+            )
         else:
             ndarrays = map(OutputFile.maybe_do_quantize, ndarrays_inner)
 
         start = time.time()
-        for i, ((name, lazy_tensor), ndarray) in enumerate(zip(model.items(), ndarrays)):
+        for i, ((name, lazy_tensor), ndarray) in enumerate(
+            zip(model.items(), ndarrays)
+        ):
             elapsed = time.time() - start
-            size = ' x '.join(f"{dim:6d}" for dim in lazy_tensor.shape)
+            size = " x ".join(f"{dim:6d}" for dim in lazy_tensor.shape)
             padi = len(str(len(model)))
-            print(f"[{i+1:{padi}d}/{len(model)}] Writing tensor {name:38s} | size {size:16} | type {lazy_tensor.data_type.name:4} | T+{int(elapsed):4}")
+            print(
+                f"[{i+1:{padi}d}/{len(model)}] Writing tensor {name:38s} | size {size:16} | type {lazy_tensor.data_type.name:4} | T+{int(elapsed):4}"
+            )
             of.gguf.write_tensor_data(ndarray)
 
         of.close()
@@ -1145,30 +1355,95 @@ def load_some_model(path: Path) -> ModelPlus:
     return model_plus
 
 
-def find_vocab_file_path(path: Path, vocab_file: str) -> Optional[Path]:
-    path2 = path / vocab_file
-    # Use `.parent` instead of /.. to handle the symlink case better.
-    path3 = path.parent / vocab_file
-
-    if path2.exists():
-        return path2
-    if path3.exists():
-        return path3
+class VocabFactory:
+    def __init__(self, path: Path):
+        self.path = path
+        self.files = {
+            "tokenizer.model": None,
+            "vocab.json": None,
+            "tokenizer.json": None,
+        }
+        self._detect_files()
+
+    def _detect_files(self):
+        for file in self.files.keys():
+            file_path = self.path / file
+            parent_file_path = self.path.parent / file
+            if file_path.exists():
+                self.files[file] = file_path
+            elif parent_file_path.exists():
+                self.files[file] = parent_file_path
+
+    def _select_file(self, vocabtype: Optional[str]) -> Path:
+        if vocabtype in ["spm", "bpe"]:
+            # For SentencePiece and BPE, return specific files as before
+            file_key = "tokenizer.model" if vocabtype == "spm" else "vocab.json"
+            if self.files[file_key]:
+                return self.files[file_key]
+            else:
+                raise FileNotFoundError(f"{vocabtype} {file_key} not found.")
+        elif vocabtype == "hfft":
+            # For Hugging Face Fast Tokenizer, return the directory path instead of a specific file
+            return self.path
+        else:
+            raise ValueError(f"Unsupported vocabulary type {vocabtype}")
+
+    def _create_special_vocab(
+        self,
+        vocab: Vocab,
+        vocabtype: str,
+        model_parent_path: Path,
+    ) -> gguf.SpecialVocab:
+        load_merges = vocabtype == "bpe"
+        n_vocab = vocab.vocab_size if hasattr(vocab, "vocab_size") else None
+        return gguf.SpecialVocab(
+            model_parent_path,
+            load_merges=load_merges,
+            special_token_types=None,  # Predetermined or passed as a parameter
+            n_vocab=n_vocab,
+        )
 
-    return None
+    def load_vocab(
+        self, vocabtype: str, model_parent_path: Path
+    ) -> Tuple[Vocab, gguf.SpecialVocab]:
+        path = self._select_file(vocabtype)
+        print(f"Loading vocab file '{path}', type '{vocabtype}'")
+
+        added_tokens_path = path.parent / "added_tokens.json"
+        if vocabtype == "bpe":
+            vocab = BpeVocab(
+                path, added_tokens_path if added_tokens_path.exists() else None
+            )
+        elif vocabtype == "spm":
+            vocab = SentencePieceVocab(
+                path, added_tokens_path if added_tokens_path.exists() else None
+            )
+        elif vocabtype == "hfft":
+            vocab = HfVocab(
+                path, added_tokens_path if added_tokens_path.exists() else None
+            )
+        else:
+            raise ValueError(f"Unsupported vocabulary type {vocabtype}")
+        special_vocab = self._create_special_vocab(
+            vocab,
+            vocabtype,
+            model_parent_path,
+        )
+        return vocab, special_vocab
 
 
-def default_outfile(model_paths: list[Path], file_type: GGMLFileType) -> Path:
+def default_output_file(model_paths: list[Path], file_type: GGMLFileType) -> Path:
     namestr = {
-        GGMLFileType.AllF32:    "f32",
+        GGMLFileType.AllF32: "f32",
         GGMLFileType.MostlyF16: "f16",
-        GGMLFileType.MostlyQ8_0:"q8_0",
+        GGMLFileType.MostlyQ8_0: "q8_0",
     }[file_type]
     ret = model_paths[0].parent / f"ggml-model-{namestr}.gguf"
     if ret in model_paths:
         sys.stderr.write(
             f"Error: Default output path ({ret}) would overwrite the input. "
-            "Please explicitly specify a path using --outfile.\n")
+            "Please explicitly specify a path using --outfile.\n"
+        )
         sys.exit(1)
     return ret
 
@@ -1178,32 +1453,111 @@ def do_dump_model(model_plus: ModelPlus) -> None:
     print(f"model_plus.format = {model_plus.format!r}")
     print(f"model_plus.vocab = {model_plus.vocab!r}")
     for name, lazy_tensor in model_plus.model.items():
-        print(f"{name}: shape={lazy_tensor.shape} type={lazy_tensor.data_type}; {lazy_tensor.description}")
+        print(
+            f"{name}: shape={lazy_tensor.shape} type={lazy_tensor.data_type}; {lazy_tensor.description}"
+        )
 
 
-def main(args_in: list[str] | None = None) -> None:
+def get_argument_parser() -> ArgumentParser:
     output_choices = ["f32", "f16"]
     if np.uint32(1) == np.uint32(1).newbyteorder("<"):
         # We currently only support Q8_0 output on little endian systems.
         output_choices.append("q8_0")
-    parser = argparse.ArgumentParser(description="Convert a LLaMa model to a GGML compatible file")
-    parser.add_argument("--awq-path",    type=Path,              help="Path to scale awq cache file", default=None)
-    parser.add_argument("--dump",        action="store_true",    help="don't convert, just show what's in the model")
-    parser.add_argument("--dump-single", action="store_true",    help="don't convert, just show what's in a single model file")
-    parser.add_argument("--vocab-only",  action="store_true",    help="extract only the vocab")
-    parser.add_argument("--outtype",     choices=output_choices, help="output format - note: q8_0 may be very slow (default: f16 or f32 based on input)")
-    parser.add_argument("--vocab-dir",   type=Path,              help="directory containing tokenizer.model, if separate from model file")
-    parser.add_argument("--outfile",     type=Path,              help="path to write to; default: based on input")
-    parser.add_argument("model",         type=Path,              help="directory containing model file, or model file itself (*.pth, *.pt, *.bin)")
-    parser.add_argument("--ctx",         type=int,               help="model training context (default: based on input)")
-    parser.add_argument("--concurrency", type=int,               help=f"concurrency used for conversion (default: {DEFAULT_CONCURRENCY})", default = DEFAULT_CONCURRENCY)
-    parser.add_argument("--bigendian",   action="store_true",    help="model is executed on big endian machine")
-    parser.add_argument("--padvocab", action="store_true", help="add pad tokens when model vocab expects more than tokenizer metadata provides")
-
-    args = parser.parse_args(args_in)
+
+    parser = argparse.ArgumentParser(
+        description="Convert a LLaMa model to a GGML compatible file"
+    )
+
+    parser.add_argument(
+        "model",
+        type=Path,
+        help="Directory containing the model file or the model file itself (*.pth, *.pt, *.bin)",
+    )
+
+    parser.add_argument(
+        "--awq-path",
+        type=Path,
+        help="Path to the Activation-aware Weight Quantization cache file",
+        default=None,
+    )
+
+    parser.add_argument(
+        "--dump",
+        action="store_true",
+        help="Display the model content without converting it",
+    )
+
+    parser.add_argument(
+        "--dump-single",
+        action="store_true",
+        help="Display the content of a single model file without conversion",
+    )
+
+    parser.add_argument(
+        "--vocab-only",
+        action="store_true",
+        help="Extract and output only the vocabulary",
+    )
+
+    parser.add_argument(
+        "--outtype",
+        choices=output_choices,
+        help="Output format - note: q8_0 may be very slow (default: f16 or f32 based on input)",
+    )
+
+    parser.add_argument(
+        "--vocab-dir",
+        type=Path,
+        help="Directory containing the tokenizer.model, if separate from the model file",
+    )
+
+    parser.add_argument(
+        "--vocab-type",
+        choices=["spm", "bpe", "hfft"],  # hfft: Hugging Face Fast Tokenizer
+        default="spm",
+        help="The vocabulary format used to define the tokenizer model (default: spm)",
+    )
+
+    parser.add_argument(
+        "--pad-vocab",
+        action="store_true",
+        help="Add padding tokens when the model's vocabulary size exceeds the tokenizer metadata",
+    )
+
+    parser.add_argument(
+        "--outfile",
+        type=Path,
+        help="Specify the path for the output file (default is based on input)",
+    )
+
+    parser.add_argument(
+        "--ctx", type=int, help="Model training context (default is based on input)"
+    )
+
+    parser.add_argument(
+        "--concurrency",
+        type=int,
+        help=f"Concurrency used for conversion (default: {DEFAULT_CONCURRENCY})",
+        default=DEFAULT_CONCURRENCY,
+    )
+
+    parser.add_argument(
+        "--big-endian",
+        action="store_true",
+        help="Indicate that the model is executed on a big-endian machine",
+    )
+
+    return parser
+
+
+def main(argv: Optional[list[str]] = None) -> None:
+    parser = get_argument_parser()
+    args = parser.parse_args(argv)
+
     if args.awq_path:
-        sys.path.insert(1, str(Path(__file__).parent / 'awq-py'))
+        sys.path.insert(1, str(Path(__file__).resolve().parent / "awq-py"))
         from awq.apply_awq import add_scale_weights
+
         tmp_model_path = args.model / "weighted_model"
         if tmp_model_path.is_dir():
             print(f"{tmp_model_path} exists as a weighted model.")
@@ -1222,22 +1576,27 @@ def main(args_in: list[str] | None = None) -> None:
     if not args.vocab_only:
         model_plus = load_some_model(args.model)
     else:
-        model_plus = ModelPlus(model = {}, paths = [args.model / 'dummy'], format = 'none', vocab = None)
+        model_plus = ModelPlus(
+            model={}, paths=[args.model / "dummy"], format="none", vocab=None
+        )
 
     if args.dump:
         do_dump_model(model_plus)
         return
+
     endianess = gguf.GGUFEndian.LITTLE
-    if args.bigendian:
+    if args.big_endian:
         endianess = gguf.GGUFEndian.BIG
 
     params = Params.load(model_plus)
     if params.n_ctx == -1:
         if args.ctx is None:
-            raise Exception("The model doesn't have a context size, and you didn't specify one with --ctx\n"
-                            "Please specify one with --ctx:\n"
-                            " - LLaMA v1: --ctx 2048\n"
-                            " - LLaMA v2: --ctx 4096\n")
+            raise Exception(
+                "The model doesn't have a context size, and you didn't specify one with --ctx\n"
+                "Please specify one with --ctx:\n"
+                " - LLaMA v1: --ctx 2048\n"
+                " - LLaMA v2: --ctx 4096\n"
+            )
         params.n_ctx = args.ctx
 
     if args.outtype:
@@ -1249,47 +1608,51 @@ def main(args_in: list[str] | None = None) -> None:
 
     print(f"params = {params}")
 
-    vocab: Vocab
+    model_parent_path = model_plus.paths[0].parent
+    vocab_path = Path(args.vocab_dir or args.model or model_parent_path)
+    vocab_factory = VocabFactory(vocab_path)
+    vocab, special_vocab = vocab_factory.load_vocab(args.vocab_type, model_parent_path)
+
     if args.vocab_only:
         if not args.outfile:
             raise ValueError("need --outfile if using --vocab-only")
-        # FIXME: Try to respect vocab_dir somehow?
-        vocab = VocabLoader(params, args.vocab_dir or args.model)
-        special_vocab = gguf.SpecialVocab(model_plus.paths[0].parent,
-                                          load_merges = True,
-                                          n_vocab = vocab.vocab_size)
         outfile = args.outfile
-        OutputFile.write_vocab_only(outfile, params, vocab, special_vocab,
-                                    endianess = endianess, pad_vocab = args.padvocab)
+        OutputFile.write_vocab_only(
+            outfile,
+            params,
+            vocab,
+            special_vocab,
+            endianess=endianess,
+            pad_vocab=args.pad_vocab,
+        )
         print(f"Wrote {outfile}")
         return
 
     if model_plus.vocab is not None and args.vocab_dir is None:
         vocab = model_plus.vocab
-    else:
-        vocab_dir = args.vocab_dir if args.vocab_dir else model_plus.paths[0].parent
-        vocab = VocabLoader(params, vocab_dir)
-
-    # FIXME: Try to respect vocab_dir somehow?
-    print(f"Vocab info: {vocab}")
-    special_vocab = gguf.SpecialVocab(model_plus.paths[0].parent,
-                                      load_merges = True,
-                                      n_vocab = vocab.vocab_size)
-
-    print(f"Special vocab info: {special_vocab}")
-    model   = model_plus.model
-    model   = convert_model_names(model, params)
-    ftype   = pick_output_type(model, args.outtype)
-    model   = convert_to_output_type(model, ftype)
-    outfile = args.outfile or default_outfile(model_plus.paths, ftype)
+
+    model = model_plus.model
+    model = convert_model_names(model, params)
+    ftype = pick_output_type(model, args.outtype)
+    model = convert_to_output_type(model, ftype)
+    outfile = args.outfile or default_output_file(model_plus.paths, ftype)
 
     params.ftype = ftype
     print(f"Writing {outfile}, format {ftype}")
 
-    OutputFile.write_all(outfile, ftype, params, model, vocab, special_vocab,
-                         concurrency = args.concurrency, endianess = endianess, pad_vocab = args.padvocab)
+    OutputFile.write_all(
+        outfile,
+        ftype,
+        params,
+        model,
+        vocab,
+        special_vocab,
+        concurrency=args.concurrency,
+        endianess=endianess,
+        pad_vocab=args.pad_vocab,
+    )
     print(f"Wrote {outfile}")
 
 
-if __name__ == '__main__':
-    main()
+if __name__ == "__main__":
+    main(sys.argv[1:])  # Exclude the first element (script name) from sys.argv