Add byte token type when tokenizer.model does not exist (#4641)
author wonjun Jang <redacted>
Wed, 27 Dec 2023 08:37:25 +0000 (17:37 +0900)
committer GitHub <redacted>
Wed, 27 Dec 2023 08:37:25 +0000 (17:37 +0900)
* Add byte token type to hf format

* remove unused variable

convert.py

index 7a3cd615e9775e0013f44b54c358eeb44020bb12..1f0c4f2f4197751660887be054447e87fb87a6a4 100755 (executable)
@@ -357,6 +357,7 @@ class VocabLoader:
             for tok in self.tokenizer.all_special_tokens
         }
         self.special_ids: set[int] = set(self.tokenizer.all_special_ids)
+        self.reverse_vocab = {id: encoded_tok for encoded_tok, id in self.tokenizer.get_vocab().items()}
         self.vocab_size_base: int = self.tokenizer.vocab_size
         self.vocab_size: int = self.vocab_size_base + len(self.added_tokens_dict)
         self.fname_tokenizer: Path = fname_tokenizer
@@ -370,15 +371,13 @@ class VocabLoader:
             self.spm = None
 
     def hf_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
-        tokenizer = self.tokenizer
-        reverse_vocab = {id: encoded_tok for encoded_tok, id in tokenizer.get_vocab().items()}
         added_tokens_ids = set(self.added_tokens_dict.values())
 
         for i in range(self.vocab_size_base):
             if i in added_tokens_ids:
                 continue
 
-            text = reverse_vocab[i].encode("utf-8")
+            text = self.reverse_vocab[i].encode("utf-8")
             yield text, self.get_token_score(i), self.get_token_type(i)
 
     def get_token_type(self, token_id: int) -> gguf.TokenType:
@@ -394,10 +393,13 @@ class VocabLoader:
             if self.spm.is_byte(token_id):
                 toktype = gguf.TokenType.BYTE
         else:
+            token = self.reverse_vocab[token_id]
             if token_id == self.unk_token_id:
                 toktype = gguf.TokenType.UNKNOWN
-            if token_id in self.special_ids:
+            elif token_id in self.special_ids:
                 toktype = gguf.TokenType.CONTROL
+            elif len(token) == 6 and token.startswith("<0x") and token.endswith(">"):
+                toktype = gguf.TokenType.BYTE
 
         return toktype
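
For context, the new branch in get_token_type can be exercised on its own. The sketch below is a minimal, hedged illustration (classify_token and toy_vocab are hypothetical names, not part of convert.py) of how a SentencePiece-style byte-fallback token such as "<0x0A>" is now classified as BYTE even when no tokenizer.model is available and self.spm is None.

import gguf  # llama.cpp's gguf-py package, which provides gguf.TokenType

def classify_token(token: str, token_id: int,
                   unk_token_id: int, special_ids: set[int]) -> gguf.TokenType:
    # Mirrors the branch added to VocabLoader.get_token_type for the case
    # where no SentencePiece model (tokenizer.model) is available.
    toktype = gguf.TokenType.NORMAL
    if token_id == unk_token_id:
        toktype = gguf.TokenType.UNKNOWN
    elif token_id in special_ids:
        toktype = gguf.TokenType.CONTROL
    elif len(token) == 6 and token.startswith("<0x") and token.endswith(">"):
        # SentencePiece-style byte-fallback tokens are spelled "<0xNN>",
        # e.g. "<0x0A>" for the newline byte.
        toktype = gguf.TokenType.BYTE
    return toktype

# Hypothetical toy vocab standing in for tokenizer.get_vocab() from a real
# Hugging Face tokenizer; the reverse mapping matches self.reverse_vocab above.
toy_vocab = {"<unk>": 0, "<s>": 1, "<0x0A>": 2, "hello": 3}
reverse_vocab = {tok_id: tok for tok, tok_id in toy_vocab.items()}

for tok_id, tok in reverse_vocab.items():
    print(tok_id, tok,
          classify_token(tok, tok_id, unk_token_id=0, special_ids={0, 1}).name)
# Token 2 ("<0x0A>") is reported as BYTE without consulting tokenizer.model.

Caching reverse_vocab in __init__ (first hunk) also lets get_token_type look up the token text by id, which the old code only did inside hf_tokens.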