]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
convert : support loading vocab from fast tokenizer config (#3633)
authorwonjun Jang <redacted>
Thu, 14 Dec 2023 08:09:34 +0000 (17:09 +0900)
committerGitHub <redacted>
Thu, 14 Dec 2023 08:09:34 +0000 (10:09 +0200)
* Add HFVocab into convert.py

* Update convert.py

* Update convert.py

* add bytes_to_unicode function

* change add_meta_vocab fucntion

* remove debug code

* remove byte_encoder

* Add newline between classes

* Check tokenizer.json when tokenizer.model is not exist.

* Move transformers dependency to local code

* Add error context with 'raise from'

* Add fast tokenizer option to BpeVocab

* Update convert.py

* Add VocabLoader and remove *Vocab class

* Add transformers dependency

* remove added tokens and check newline token to decide spm or bpe

* Update convert.py

* Add special token type

* Update convert.py

* Update convert.py

* Update convert.py

* Fix typo in convert.py

* Fix when params.n_vocab < tokenizer vocab size

* update vocab class

* change funtion name

* Remove unused variable/functions, add types to class variable and methods, delete blank liens

* fix flake8 warnings

* code style cleanup

* make mypy happy

* change exception

---------

Co-authored-by: Jared Van Bortel <redacted>
convert.py
requirements.txt

index e4b69d172f728bda274889e38defded1fef2ac4b..7a3cd615e9775e0013f44b54c358eeb44020bb12 100755 (executable)
@@ -10,6 +10,7 @@ import itertools
 import json
 import math
 import mmap
+import os
 import pickle
 import re
 import signal
@@ -18,15 +19,15 @@ import sys
 import time
 import zipfile
 from abc import ABCMeta, abstractmethod
+from collections import OrderedDict
 from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
 from dataclasses import dataclass
 from pathlib import Path
-from typing import IO, TYPE_CHECKING, Any, Callable, Iterable, Literal, TypeVar
+from typing import IO, TYPE_CHECKING, Any, Callable, Iterable, Literal, Optional, TypeVar, cast
 
 import numpy as np
 from sentencepiece import SentencePieceProcessor
 
-import os
 if 'NO_LOCAL_GGUF' not in os.environ:
     sys.path.insert(1, str(Path(__file__).parent / 'gguf-py'))
 import gguf
@@ -327,127 +328,138 @@ class Params:
         return params
 
 
-#
-# vocab
-#
+class VocabLoader:
+    def __init__(self, params: Params, fname_tokenizer: Path) -> None:
+        try:
+            from transformers import AutoTokenizer
+        except ImportError as e:
+            raise ImportError(
+                "To use VocabLoader, please install the `transformers` package. "
+                "You can install it with `pip install transformers`."
+            ) from e
 
-class BpeVocab:
-    def __init__(self, fname_tokenizer: Path, fname_added_tokens: Path | None) -> None:
-        self.bpe_tokenizer = json.loads(open(str(fname_tokenizer), encoding="utf-8").read())
-        added_tokens: dict[str, int]
-        if fname_added_tokens is not None:
-            # FIXME: Verify that added tokens here _cannot_ overlap with the main vocab.
-            added_tokens = json.load(open(fname_added_tokens, encoding="utf-8"))
-        else:
-            # Fall back to trying to find the added tokens in tokenizer.json
-            tokenizer_json_file = fname_tokenizer.parent / 'tokenizer.json'
-            if not tokenizer_json_file.is_file():
-                added_tokens = {}
-            else:
-                tokenizer_json = json.load(open(tokenizer_json_file, encoding="utf-8"))
-                added_tokens = dict(
-                    (item['content'], item['id'])
-                    for item in tokenizer_json.get('added_tokens', [])
-                    # Added tokens here can be duplicates of the main vocabulary.
-                    if item['content'] not in self.bpe_tokenizer)
-
-        vocab_size: int = len(self.bpe_tokenizer)
-        expected_ids    = list(range(vocab_size, vocab_size + len(added_tokens)))
-        actual_ids      = sorted(added_tokens.values())
-        if expected_ids != actual_ids:
-            expected_end_id = vocab_size + len(actual_ids) - 1
-            raise Exception(f"Expected the {len(actual_ids)} added token ID(s) to be sequential in the range {vocab_size} - {expected_end_id}; got {actual_ids}")
-
-        items = sorted(added_tokens.items(), key=lambda text_idx: text_idx[1])
-        self.added_tokens_list    = [text for (text, idx) in items]
-        self.vocab_size_base: int = vocab_size
-        self.vocab_size: int      = self.vocab_size_base + len(self.added_tokens_list)
-        self.fname_tokenizer      = fname_tokenizer
-        self.fname_added_tokens   = fname_added_tokens
-
-    def bpe_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
-        tokenizer = self.bpe_tokenizer
-        reverse_vocab = {id: encoded_tok for encoded_tok, id in tokenizer.items()}
-
-        for i, _ in enumerate(tokenizer):
-            yield reverse_vocab[i], 0.0, gguf.TokenType.NORMAL
+        try:
+            self.tokenizer = AutoTokenizer.from_pretrained(str(fname_tokenizer), trust_remote_code=True)
+        except ValueError:
+            self.tokenizer = AutoTokenizer.from_pretrained(str(fname_tokenizer), use_fast=False, trust_remote_code=True)
 
-    def added_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
-        for text in self.added_tokens_list:
-            score = -1000.0
-            yield text.encode("utf-8"), score, gguf.TokenType.CONTROL
+        self.added_tokens_dict: OrderedDict[str, int] = OrderedDict()
 
-    def all_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
-        yield from self.bpe_tokens()
-        yield from self.added_tokens()
+        for tok, tokidx in sorted(self.tokenizer.get_added_vocab().items(), key=lambda x: x[1]):
+            if tokidx >= params.n_vocab or tokidx < self.tokenizer.vocab_size:
+                continue
 
-    def __repr__(self) -> str:
-        return f"<BpeVocab with {self.vocab_size_base} base tokens and {len(self.added_tokens_list)} added tokens>"
+            self.added_tokens_dict[tok] = tokidx
 
+        self.unk_token_id: int = self.tokenizer.unk_token_id
+        self.specials: dict[str, int] = {
+            tok: self.tokenizer.get_vocab()[tok]
+            for tok in self.tokenizer.all_special_tokens
+        }
+        self.special_ids: set[int] = set(self.tokenizer.all_special_ids)
+        self.vocab_size_base: int = self.tokenizer.vocab_size
+        self.vocab_size: int = self.vocab_size_base + len(self.added_tokens_dict)
+        self.fname_tokenizer: Path = fname_tokenizer
 
-class SentencePieceVocab:
-    def __init__(self, fname_tokenizer: Path, fname_added_tokens: Path | None) -> None:
-        self.sentencepiece_tokenizer = SentencePieceProcessor(str(fname_tokenizer))
-        added_tokens: dict[str, int]
-        if fname_added_tokens is not None:
-            added_tokens = json.load(open(fname_added_tokens, encoding="utf-8"))
+        vocab_file = "tokenizer.model"
+        path_candidate = find_vocab_file_path(self.fname_tokenizer, vocab_file)
+        if path_candidate is not None:
+            self.spm = SentencePieceProcessor(str(path_candidate))
+            print(self.spm.vocab_size(), self.vocab_size_base)
         else:
-            added_tokens = {}
-
-        vocab_size: int = self.sentencepiece_tokenizer.vocab_size()
+            self.spm = None
 
-        new_tokens       = {id: piece for piece, id in added_tokens.items() if id >= vocab_size}
-        expected_new_ids = list(range(vocab_size, vocab_size + len(new_tokens)))
-        actual_new_ids   = sorted(new_tokens.keys())
+    def hf_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
+        tokenizer = self.tokenizer
+        reverse_vocab = {id: encoded_tok for encoded_tok, id in tokenizer.get_vocab().items()}
+        added_tokens_ids = set(self.added_tokens_dict.values())
 
-        if expected_new_ids != actual_new_ids:
-            raise ValueError(f"Expected new token IDs {expected_new_ids} to be sequential; got {actual_new_ids}")
+        for i in range(self.vocab_size_base):
+            if i in added_tokens_ids:
+                continue
 
-        # Token pieces that were added to the base vocabulary.
-        self.added_tokens_list  = [new_tokens[id] for id in actual_new_ids]
-        self.vocab_size_base    = vocab_size
-        self.vocab_size         = self.vocab_size_base + len(self.added_tokens_list)
-        self.fname_tokenizer    = fname_tokenizer
-        self.fname_added_tokens = fname_added_tokens
+            text = reverse_vocab[i].encode("utf-8")
+            yield text, self.get_token_score(i), self.get_token_type(i)
 
-    def sentencepiece_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
-        tokenizer = self.sentencepiece_tokenizer
-        for i in range(tokenizer.vocab_size()):
-            piece = tokenizer.id_to_piece(i)
-            text: bytes = piece.encode("utf-8")
-            score: float = tokenizer.get_score(i)
+    def get_token_type(self, token_id: int) -> gguf.TokenType:
+        toktype = gguf.TokenType.NORMAL
 
-            toktype = gguf.TokenType.NORMAL
-            if tokenizer.is_unknown(i):
+        if self.spm is not None and token_id < self.spm.vocab_size():
+            if self.spm.is_unknown(token_id):
                 toktype = gguf.TokenType.UNKNOWN
-            if tokenizer.is_control(i):
+            if self.spm.is_control(token_id):
                 toktype = gguf.TokenType.CONTROL
-
-            # NOTE: I think added_tokens are user defined.
-            # ref: https://github.com/google/sentencepiece/blob/master/src/sentencepiece_model.proto
-            # if tokenizer.is_user_defined(i): toktype = gguf.TokenType.USER_DEFINED
-
-            if tokenizer.is_unused(i):
+            if self.spm.is_unused(token_id):
                 toktype = gguf.TokenType.UNUSED
-            if tokenizer.is_byte(i):
+            if self.spm.is_byte(token_id):
                 toktype = gguf.TokenType.BYTE
+        else:
+            if token_id == self.unk_token_id:
+                toktype = gguf.TokenType.UNKNOWN
+            if token_id in self.special_ids:
+                toktype = gguf.TokenType.CONTROL
 
-            yield text, score, toktype
+        return toktype
+
+    def get_token_score(self, token_id: int) -> float:
+        if self.spm is not None and token_id < self.spm.vocab_size():
+            return cast(float, self.spm.get_score(token_id))
+        return 0.0
 
     def added_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
-        for text in self.added_tokens_list:
-            score = -1000.0
-            yield text.encode("utf-8"), score, gguf.TokenType.USER_DEFINED
+
+        for text in self.added_tokens_dict:
+            if text in self.specials:
+
+                toktype = self.get_token_type(self.specials[text])
+                score = self.get_token_score(self.specials[text])
+
+            else:
+                toktype = gguf.TokenType.USER_DEFINED
+                score = -1000.0
+
+            yield text.encode("utf-8"), score, toktype
+
+    def has_newline_token(self) -> bool:
+        return '<0x0A>' in self.tokenizer.vocab or '\n' in self.tokenizer.vocab
 
     def all_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
-        yield from self.sentencepiece_tokens()
+        yield from self.hf_tokens()
         yield from self.added_tokens()
 
+    def get_vocab_type(self) -> str:
+        path_candidates = []
+        vocab_file = "tokenizer.model"
+        path_candidates.append(vocab_file)
+        path_candidate = find_vocab_file_path(self.fname_tokenizer, vocab_file)
+        if path_candidate is not None:
+            return "llama"
+
+        vocab_file = "vocab.json"
+        path_candidates.append(vocab_file)
+        path_candidate = find_vocab_file_path(self.fname_tokenizer, vocab_file)
+        if path_candidate is not None:
+            return "gpt2"
+
+        vocab_file = "tokenizer.json"
+        path_candidates.append(vocab_file)
+        path_candidate = find_vocab_file_path(self.fname_tokenizer, vocab_file)
+        if path_candidate:
+            if not self.has_newline_token():
+                return "gpt2"
+            return "llama"
+
+        raise FileNotFoundError(
+            f"Could not find {path_candidates} in {self.fname_tokenizer} or its parent; "
+            "if it's in another directory, pass the directory as --vocab-dir"
+        )
+
     def __repr__(self) -> str:
-        return f"<SentencePieceVocab with {self.vocab_size_base} base tokens and {len(self.added_tokens_list)} added tokens>"
+        return f"<VocabLoader with {self.vocab_size_base} base tokens and {len(self.added_tokens_dict)} added tokens>"
+
 
+Vocab: TypeAlias = 'VocabLoader'
 
-Vocab: TypeAlias = 'BpeVocab | SentencePieceVocab'
 
 #
 # data loading
@@ -824,20 +836,27 @@ def bounded_parallel_map(func: Callable[[In], Out], iterable: Iterable[In], conc
             yield result
 
 
-def check_vocab_size(params: Params, vocab: Vocab) -> None:
+def check_vocab_size(params: Params, vocab: Vocab, pad_vocab: bool = False) -> None:
     if params.n_vocab != vocab.vocab_size:
-        assert isinstance(vocab, BpeVocab) or isinstance(vocab, SentencePieceVocab)
-        if params.n_vocab == vocab.vocab_size_base:
+        if params.n_vocab == vocab.vocab_size:
             print("Ignoring added_tokens.json since model matches vocab size without it.")
-            vocab.added_tokens_list = []
-            vocab.vocab_size = vocab.vocab_size_base
+            vocab.added_tokens_dict = OrderedDict()
+            vocab.vocab_size = vocab.vocab_size
+            return
+
+        if pad_vocab and params.n_vocab > vocab.vocab_size:
+            pad_count = params.n_vocab - vocab.vocab_size
+            print(f'Padding vocab with {pad_count} token(s) - <dummy00001> through <dummy{pad_count:05}>')
+            for i in range(1, (params.n_vocab - vocab.vocab_size) + 1):
+                vocab.added_tokens_dict[f'<dummy{i:05}>'] = -1
+            vocab.vocab_size = params.n_vocab
             return
         msg = f"Vocab size mismatch (model has {params.n_vocab}, but {vocab.fname_tokenizer}"
-        if vocab.fname_added_tokens is not None:
-            msg += f" combined with {vocab.fname_added_tokens}"
         msg += f" has {vocab.vocab_size})."
-        if vocab.vocab_size < params.n_vocab < vocab.vocab_size + 20 and vocab.fname_added_tokens is None:
+        if vocab.vocab_size < params.n_vocab < vocab.vocab_size + 20:
             msg += f"  Most likely you are missing added_tokens.json (should be in {vocab.fname_tokenizer.parent})."
+        if vocab.vocab_size < params.n_vocab:
+            msg += " Possibly try using the --padvocab option."
         raise Exception(msg)
 
 
@@ -901,12 +920,8 @@ class OutputFile:
             scores.append(score)
             toktypes.append(toktype)
 
-        if isinstance(vocab, SentencePieceVocab):
-            self.gguf.add_tokenizer_model("llama")
-        elif isinstance(vocab, BpeVocab):
-            self.gguf.add_tokenizer_model("gpt2")
-        else:
-            raise ValueError('Unknown vocab type: Not BpeVocab or SentencePieceVocab')
+        vocab_type = vocab.get_vocab_type()
+        self.gguf.add_tokenizer_model(vocab_type)
         self.gguf.add_token_list(tokens)
         self.gguf.add_token_scores(scores)
         self.gguf.add_token_types(toktypes)
@@ -932,8 +947,12 @@ class OutputFile:
         self.gguf.close()
 
     @staticmethod
-    def write_vocab_only(fname_out: Path, params: Params, vocab: Vocab, svocab: gguf.SpecialVocab, endianess:gguf.GGUFEndian = gguf.GGUFEndian.LITTLE) -> None:
-        check_vocab_size(params, vocab)
+    def write_vocab_only(
+        fname_out: Path, params: Params, vocab: Vocab, svocab: gguf.SpecialVocab,
+        endianess: gguf.GGUFEndian = gguf.GGUFEndian.LITTLE,
+        pad_vocab: bool = False,
+    ) -> None:
+        check_vocab_size(params, vocab, pad_vocab = pad_vocab)
 
         of = OutputFile(fname_out, endianess=endianess)
 
@@ -960,8 +979,13 @@ class OutputFile:
         return dt.quantize(arr)
 
     @staticmethod
-    def write_all(fname_out: Path, ftype: GGMLFileType, params: Params, model: LazyModel, vocab: Vocab, svocab: gguf.SpecialVocab, concurrency: int = DEFAULT_CONCURRENCY, endianess: gguf.GGUFEndian = gguf.GGUFEndian.LITTLE) -> None:
-        check_vocab_size(params, vocab)
+    def write_all(
+        fname_out: Path, ftype: GGMLFileType, params: Params, model: LazyModel, vocab: Vocab, svocab: gguf.SpecialVocab,
+        concurrency: int = DEFAULT_CONCURRENCY,
+        endianess: gguf.GGUFEndian = gguf.GGUFEndian.LITTLE,
+        pad_vocab: bool = False,
+    ) -> None:
+        check_vocab_size(params, vocab, pad_vocab = pad_vocab)
 
         of = OutputFile(fname_out, endianess=endianess)
 
@@ -1119,35 +1143,17 @@ def load_some_model(path: Path) -> ModelPlus:
     return model_plus
 
 
-def load_vocab(path: Path, vocabtype: str | None) -> Vocab:
-    # Be extra-friendly and accept either a file or a directory.  Also, if it's
-    # a directory, it might be the model directory, and tokenizer.model might
-    # be in the parent of that.
-    if path.is_dir():
-        vocab_file = "tokenizer.model"
-        if vocabtype == 'bpe':
-            vocab_file = "vocab.json"
-        path2 = path / vocab_file
-        # Use `.parent` instead of /.. to handle the symlink case better.
-        path3 = path.parent / vocab_file
-        if path2.exists():
-            path = path2
-        elif path3.exists():
-            path = path3
-        else:
-            raise FileNotFoundError(
-                f"Could not find {vocab_file} in {path} or its parent; "
-                "if it's in another directory, pass the directory as --vocab-dir")
+def find_vocab_file_path(path: Path, vocab_file: str) -> Optional[Path]:
+    path2 = path / vocab_file
+    # Use `.parent` instead of /.. to handle the symlink case better.
+    path3 = path.parent / vocab_file
 
-    print(f"Loading vocab file '{path}', type '{vocabtype}'")
+    if path2.exists():
+        return path2
+    if path3.exists():
+        return path3
 
-    added_tokens_path = path.parent / "added_tokens.json"
-    if vocabtype == "bpe":
-        return BpeVocab(path, added_tokens_path if added_tokens_path.exists() else None)
-    elif vocabtype == "spm":
-        return SentencePieceVocab(path, added_tokens_path if added_tokens_path.exists() else None)
-    else:
-        raise ValueError(f"Unsupported vocabulary type {vocabtype}")
+    return None
 
 
 def default_outfile(model_paths: list[Path], file_type: GGMLFileType) -> Path:
@@ -1185,11 +1191,11 @@ def main(args_in: list[str] | None = None) -> None:
     parser.add_argument("--outtype",     choices=output_choices, help="output format - note: q8_0 may be very slow (default: f16 or f32 based on input)")
     parser.add_argument("--vocab-dir",   type=Path,              help="directory containing tokenizer.model, if separate from model file")
     parser.add_argument("--outfile",     type=Path,              help="path to write to; default: based on input")
-    parser.add_argument("model",         type=Path,              help="directory containing model file, or model file itself (*.pth, *.pt, *.bin, *.safetensors)")
-    parser.add_argument("--vocabtype",   choices=["spm", "bpe"], help="vocab format (default: spm)", default="spm")
+    parser.add_argument("model",         type=Path,              help="directory containing model file, or model file itself (*.pth, *.pt, *.bin)")
     parser.add_argument("--ctx",         type=int,               help="model training context (default: based on input)")
     parser.add_argument("--concurrency", type=int,               help=f"concurrency used for conversion (default: {DEFAULT_CONCURRENCY})", default = DEFAULT_CONCURRENCY)
     parser.add_argument("--bigendian",   action="store_true",    help="model is executed on big endian machine")
+    parser.add_argument("--padvocab", action="store_true", help="add pad tokens when model vocab expects more than tokenizer metadata provides")
 
     args = parser.parse_args(args_in)
     if args.dump_single:
@@ -1232,12 +1238,13 @@ def main(args_in: list[str] | None = None) -> None:
         if not args.outfile:
             raise ValueError("need --outfile if using --vocab-only")
         # FIXME: Try to respect vocab_dir somehow?
-        vocab = load_vocab(args.vocab_dir or args.model, args.vocabtype)
+        vocab = VocabLoader(params, args.vocab_dir or args.model)
         special_vocab = gguf.SpecialVocab(model_plus.paths[0].parent,
-                                          load_merges = args.vocabtype == 'bpe',
+                                          load_merges = True,
                                           n_vocab = vocab.vocab_size)
         outfile = args.outfile
-        OutputFile.write_vocab_only(outfile, params, vocab, special_vocab)
+        OutputFile.write_vocab_only(outfile, params, vocab, special_vocab,
+                                    endianess = endianess, pad_vocab = args.padvocab)
         print(f"Wrote {outfile}")
         return
 
@@ -1245,12 +1252,15 @@ def main(args_in: list[str] | None = None) -> None:
         vocab = model_plus.vocab
     else:
         vocab_dir = args.vocab_dir if args.vocab_dir else model_plus.paths[0].parent
-        vocab = load_vocab(vocab_dir, args.vocabtype)
+        vocab = VocabLoader(params, vocab_dir)
+
     # FIXME: Try to respect vocab_dir somehow?
+    print(f"Vocab info: {vocab}")
     special_vocab = gguf.SpecialVocab(model_plus.paths[0].parent,
-                                      load_merges = args.vocabtype == 'bpe',
+                                      load_merges = True,
                                       n_vocab = vocab.vocab_size)
 
+    print(f"Special vocab info: {special_vocab}")
     model   = model_plus.model
     model   = convert_model_names(model, params)
     ftype   = pick_output_type(model, args.outtype)
@@ -1260,7 +1270,8 @@ def main(args_in: list[str] | None = None) -> None:
     params.ftype = ftype
     print(f"Writing {outfile}, format {ftype}")
 
-    OutputFile.write_all(outfile, ftype, params, model, vocab, special_vocab, concurrency = args.concurrency, endianess=endianess)
+    OutputFile.write_all(outfile, ftype, params, model, vocab, special_vocab,
+                         concurrency = args.concurrency, endianess = endianess, pad_vocab = args.padvocab)
     print(f"Wrote {outfile}")
 
 
index 81c909d0ba7fe6e5115945d83a54d57f9a849f89..badfec3be804cce60b87fed472e5807b44246fc9 100644 (file)
@@ -1,3 +1,4 @@
 numpy==1.24.4
 sentencepiece==0.1.98
+transformers>=4.34.0
 gguf>=0.1.0