]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
llama : fix pre-tokenization of non-special added tokens (#8228)
authorcompilade <redacted>
Sun, 14 Jul 2024 03:35:10 +0000 (23:35 -0400)
committerGitHub <redacted>
Sun, 14 Jul 2024 03:35:10 +0000 (23:35 -0400)
* llama : fix mpt and olmo pre-tokenizer

* llama : pre-tokenize non-special user-defined tokens first

* llama : fix detection of control-like user-defined tokens

* convert_hf : identify which user-defined tokens are control tokens

Only used in _set_vocab_gpt2() for now.

* convert_hf : identify more added control tokens for SPM tokenziers

This makes Gemma and Gemma-2 tokenize pretty much EVERYTHING correctly,
including HTML tags and consecutive spaces,
but it unfortunately requires model re-conversion.

There seems to be a weird behavior of the HF tokenizer for Gemma,
which prefers to use the 16-space token over more lengthy space tokens,
while using the SentencePiece tokenizer does not do this.
(the implementation in llama.cpp has the same behavior as SentencePiece)

* llama : fix wrong pre-tokenization of byte tokens

* llama : fix Viking pre-tokenizer regex

The order was previously wrong, which caused errors in some tests.

* llama : fix command-r detokenization

* convert_hf : reduce usages of the UNKNOWN token type

* llama : add UNKNOWN tokens in the special tokens cache

* convert_hf : reduce usages of UNKNOWN for InternLM2

This makes the changes from #8321 more consistent
with the other changes made here.

* test-tokenizer-random : reduce potential confilcts with #8379

* test-tokenizer-random : add a failing edge case for falcon

convert_hf_to_gguf.py
src/llama.cpp
tests/test-tokenizer-0.cpp
tests/test-tokenizer-random.py

index cf930be17a6e075405e4569e2807788e0ddac11b..af82cd6cd05c996c18ebc195df71a75340c3e082 100755 (executable)
@@ -373,6 +373,29 @@ class Model:
         except KeyError:
             raise NotImplementedError(f'Architecture {arch!r} not supported!') from None
 
+    def does_token_look_special(self, token: str | bytes) -> bool:
+        if isinstance(token, (bytes, bytearray)):
+            token_text = token.decode(encoding="utf-8")
+        elif isinstance(token, memoryview):
+            token_text = token.tobytes().decode(encoding="utf-8")
+        else:
+            token_text = token
+
+        # Some models mark some added tokens which ought to be control tokens as not special.
+        # (e.g. command-r, command-r-plus, deepseek-coder, gemma{,-2})
+        seems_special = token_text in (
+            "<pad>",  # deepseek-coder
+            "<mask>", "<2mass>", "[@BOS@]",  # gemma{,-2}
+        )
+
+        seems_special = seems_special or (token_text.startswith("<|") and token_text.endswith("|>"))
+        seems_special = seems_special or (token_text.startswith("<|") and token_text.endswith("|>"))  # deepseek-coder
+
+        # TODO: should these be marked as UNUSED instead? (maybe not)
+        seems_special = seems_special or (token_text.startswith("<unused") and token_text.endswith(">"))  # gemma{,-2}
+
+        return seems_special
+
     # used for GPT-2 BPE and WordPiece vocabs
     def get_vocab_base(self) -> tuple[list[str], list[int], str]:
         tokens: list[str] = []
@@ -391,16 +414,18 @@ class Model:
         for i in range(vocab_size):
             if i not in reverse_vocab:
                 tokens.append(f"[PAD{i}]")
-                toktypes.append(gguf.TokenType.USER_DEFINED)
-            elif reverse_vocab[i] in added_vocab:
-                tokens.append(reverse_vocab[i])
-                if tokenizer.added_tokens_decoder[i].special:
-                    toktypes.append(gguf.TokenType.CONTROL)
-                else:
-                    toktypes.append(gguf.TokenType.USER_DEFINED)
+                toktypes.append(gguf.TokenType.UNUSED)
             else:
-                tokens.append(reverse_vocab[i])
-                toktypes.append(gguf.TokenType.NORMAL)
+                token: str = reverse_vocab[i]
+                if token in added_vocab:
+                    if tokenizer.added_tokens_decoder[i].special or self.does_token_look_special(token):
+                        toktypes.append(gguf.TokenType.CONTROL)
+                    else:
+                        token = token.replace(b"\xe2\x96\x81".decode("utf-8"), " ")  # pre-normalize user-defined spaces
+                        toktypes.append(gguf.TokenType.USER_DEFINED)
+                else:
+                    toktypes.append(gguf.TokenType.NORMAL)
+                tokens.append(token)
 
         return tokens, toktypes, tokpre
 
@@ -559,7 +584,7 @@ class Model:
         for i in range(vocab_size):
             if i not in reverse_vocab:
                 tokens.append(f"[PAD{i}]")
-                toktypes.append(gguf.TokenType.USER_DEFINED)
+                toktypes.append(gguf.TokenType.UNUSED)
             elif reverse_vocab[i] in added_vocab:
                 tokens.append(reverse_vocab[i])
                 toktypes.append(gguf.TokenType.CONTROL)
@@ -609,7 +634,7 @@ class Model:
 
         tokens: list[bytes] = [f"[PAD{i}]".encode("utf-8") for i in range(vocab_size)]
         scores: list[float] = [-10000.0] * vocab_size
-        toktypes: list[int] = [SentencePieceTokenTypes.UNKNOWN] * vocab_size
+        toktypes: list[int] = [SentencePieceTokenTypes.UNUSED] * vocab_size
 
         for token_id in range(tokenizer.vocab_size()):
             piece = tokenizer.IdToPiece(token_id)
@@ -644,6 +669,25 @@ class Model:
                     scores[token_id] = -1000.0
                     toktypes[token_id] = SentencePieceTokenTypes.USER_DEFINED
 
+        tokenizer_config_file = self.dir_model / 'tokenizer_config.json'
+        if tokenizer_config_file.is_file():
+            with open(tokenizer_config_file, "r", encoding="utf-8") as f:
+                tokenizer_config_json = json.load(f)
+                added_tokens_decoder = tokenizer_config_json.get("added_tokens_decoder", {})
+                for token_id, token_data in added_tokens_decoder.items():
+                    token_id = int(token_id)
+                    token: str = token_data["content"]
+                    if toktypes[token_id] != SentencePieceTokenTypes.UNUSED:
+                        assert tokens[token_id] == token.encode("utf-8")
+                    if token_data.get("special") or self.does_token_look_special(token):
+                        toktypes[token_id] = SentencePieceTokenTypes.CONTROL
+                    else:
+                        token = token.replace(b"\xe2\x96\x81".decode("utf-8"), " ")  # pre-normalize user-defined spaces
+                        toktypes[token_id] = SentencePieceTokenTypes.USER_DEFINED
+
+                    scores[token_id] = -1000.0
+                    tokens[token_id] = token.encode("utf-8")
+
         if vocab_size > len(tokens):
             pad_count = vocab_size - len(tokens)
             logger.debug(f"Padding vocab with {pad_count} token(s) - [PAD1] through [PAD{pad_count}]")
@@ -1266,7 +1310,7 @@ class StableLMModel(Model):
         if (self.dir_model / "tokenizer.json").is_file():
             self._set_vocab_gpt2()
         else:
-            # StableLM 2 1.6B uses a vocab in a similar format to Qwen's vocab
+            # StableLM 2 1.6B used to have a vocab in a similar format to Qwen's vocab
             self._set_vocab_qwen()
 
     def set_gguf_parameters(self):
@@ -1578,7 +1622,6 @@ class DbrxModel(Model):
         self.gguf_writer.add_rope_freq_base(attn_config["rope_theta"])
 
         self.gguf_writer.add_clamp_kqv(attn_config["clip_qkv"])
-        self.gguf_writer.add_file_type(self.ftype)
 
         self.gguf_writer.add_expert_count(ffn_config["moe_num_experts"])
         self.gguf_writer.add_expert_used_count(ffn_config["moe_top_k"])
@@ -1872,7 +1915,7 @@ class Phi3MiniModel(Model):
 
         tokens: list[bytes] = [f"[PAD{i}]".encode("utf-8") for i in range(vocab_size)]
         scores: list[float] = [-10000.0] * vocab_size
-        toktypes: list[int] = [SentencePieceTokenTypes.UNKNOWN] * vocab_size
+        toktypes: list[int] = [SentencePieceTokenTypes.UNUSED] * vocab_size
 
         for token_id in range(tokenizer.vocab_size()):
 
@@ -1917,7 +1960,7 @@ class Phi3MiniModel(Model):
                 for token_id, foken_data in added_tokens_decoder.items():
                     token_id = int(token_id)
                     token = foken_data["content"].encode("utf-8")
-                    if toktypes[token_id] != SentencePieceTokenTypes.UNKNOWN:
+                    if toktypes[token_id] != SentencePieceTokenTypes.UNUSED:
                         assert tokens[token_id] == token
                     tokens[token_id] = token
                     scores[token_id] = -1000.0
@@ -1933,7 +1976,7 @@ class Phi3MiniModel(Model):
                 for foken_data in added_tokens:
                     token_id = int(foken_data["id"])
                     token = foken_data["content"].encode("utf-8")
-                    if toktypes[token_id] != SentencePieceTokenTypes.UNKNOWN:
+                    if toktypes[token_id] != SentencePieceTokenTypes.UNUSED:
                         assert tokens[token_id] == token
                     tokens[token_id] = token
                     scores[token_id] = -1000.0
@@ -2145,7 +2188,7 @@ class InternLM2Model(Model):
                 toktype = SentencePieceTokenTypes.BYTE
             # take care of ununsed raw token
             if piece.startswith('[UNUSED'):
-                toktype = SentencePieceTokenTypes.UNKNOWN
+                toktype = SentencePieceTokenTypes.UNUSED
 
             tokens.append(text)
             scores.append(score)
@@ -2175,7 +2218,7 @@ class InternLM2Model(Model):
                     if token == chat_eos_token:
                         chat_eos_token_id = token_id
                     token = token.encode("utf-8")
-                    if toktypes[token_id] != SentencePieceTokenTypes.UNKNOWN:
+                    if toktypes[token_id] != SentencePieceTokenTypes.UNUSED:
                         assert(tokens[token_id] == token)
                     tokens[token_id] = token
                     scores[token_id] = -1000.0
@@ -2194,7 +2237,7 @@ class InternLM2Model(Model):
                     if token == chat_eos_token:
                         chat_eos_token_id = token_id
                     token = token.encode("utf-8")
-                    if toktypes[token_id] != SentencePieceTokenTypes.UNKNOWN:
+                    if toktypes[token_id] != SentencePieceTokenTypes.UNUSED:
                         assert(tokens[token_id] == token)
                     tokens[token_id] = token
                     scores[token_id] = -1000.0
@@ -2434,19 +2477,7 @@ class Gemma2Model(Model):
     model_arch = gguf.MODEL_ARCH.GEMMA2
 
     def set_vocab(self):
-        tokens, scores, toktypes = self._create_vocab_sentencepiece()
-        # hack: This is required so that we can properly use start/end-of-turn for chat template
-        for i in range(108):
-            # including <unusedX>, <start_of_turn>, <end_of_turn>
-            toktypes[i] = SentencePieceTokenTypes.CONTROL
-        self.gguf_writer.add_tokenizer_model("llama")
-        self.gguf_writer.add_tokenizer_pre("default")
-        self.gguf_writer.add_token_list(tokens)
-        self.gguf_writer.add_token_scores(scores)
-        self.gguf_writer.add_token_types(toktypes)
-
-        special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens))
-        special_vocab.add_to_gguf(self.gguf_writer)
+        self._set_vocab_sentencepiece()
 
         self.gguf_writer.add_add_space_prefix(False)
 
@@ -2770,7 +2801,7 @@ class ArcticModel(Model):
 
         tokens: list[bytes] = [f"[PAD{i}]".encode("utf-8") for i in range(vocab_size)]
         scores: list[float] = [-10000.0] * vocab_size
-        toktypes: list[int] = [SentencePieceTokenTypes.UNKNOWN] * vocab_size
+        toktypes: list[int] = [SentencePieceTokenTypes.UNUSED] * vocab_size
 
         for token_id in range(tokenizer.vocab_size()):
 
@@ -3025,7 +3056,7 @@ class T5Model(Model):
 
         tokens: list[bytes] = [f"[PAD{i}]".encode("utf-8") for i in range(vocab_size)]
         scores: list[float] = [-10000.0] * vocab_size
-        toktypes: list[int] = [SentencePieceTokenTypes.UNKNOWN] * vocab_size
+        toktypes: list[int] = [SentencePieceTokenTypes.UNUSED] * vocab_size
 
         for token_id in range(tokenizer.vocab_size()):
             piece = tokenizer.IdToPiece(token_id)
@@ -3243,15 +3274,14 @@ class ChatGLMModel(Model):
             if len(piece) != 0 and token_id < tokenizer.tokenizer.sp_model.vocab_size():
                 score = tokenizer.tokenizer.sp_model.get_score(token_id)
 
-            if len(piece) == 0:
-                text = f"[PAD{token_id}]".encode("utf-8")
-
             if token_id >= tokenizer.tokenizer.sp_model.vocab_size():
                 if piece in special_tokens:
-                    # show special tokens in prompt
-                    toktype = SentencePieceTokenTypes.USER_DEFINED
+                    toktype = SentencePieceTokenTypes.CONTROL
+                elif len(piece) == 0:
+                    text = f"[PAD{token_id}]".encode("utf-8")
+                    toktype = SentencePieceTokenTypes.UNUSED
                 else:
-                    toktype = SentencePieceTokenTypes.UNKNOWN
+                    toktype = SentencePieceTokenTypes.USER_DEFINED
                 tokens.append(text)
                 scores.append(score)
                 toktypes.append(toktype)
@@ -3340,7 +3370,7 @@ class ChatGLMModel(Model):
         for i in range(vocab_size):
             if i not in reverse_vocab:
                 tokens.append(f"[PAD{i}]")
-                toktypes.append(gguf.TokenType.USER_DEFINED)
+                toktypes.append(gguf.TokenType.UNUSED)
             elif reverse_vocab[i] in added_vocab:
                 tokens.append(reverse_vocab[i])
                 if tokenizer.added_tokens_decoder[i].special:
index 59b76a6d80cdf801740e4056bb6882f746bbbfb2..77d34dca280fc5fb56ffa1b1c17bbe372d98458f 100644 (file)
@@ -5419,6 +5419,7 @@ static void llm_load_vocab(
             } else if (
                 tokenizer_pre == "command-r") {
                 vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_COMMAND_R;
+                vocab.tokenizer_clean_spaces = false;
             } else if (
                 tokenizer_pre == "qwen2") {
                 vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_QWEN2;
@@ -5652,7 +5653,7 @@ static void llm_load_vocab(
     // build special tokens cache
     {
         for (llama_vocab::id id = 0; id < (llama_vocab::id)n_vocab; ++id) {
-            if (!(vocab.id_to_token[id].attr & LLAMA_TOKEN_ATTR_NORMAL)) {
+            if (vocab.id_to_token[id].attr & (LLAMA_TOKEN_ATTR_CONTROL | LLAMA_TOKEN_ATTR_USER_DEFINED | LLAMA_TOKEN_ATTR_UNKNOWN)) {
                 vocab.cache_special_tokens.push_back(id);
             }
         }
@@ -15411,17 +15412,6 @@ struct llm_tokenizer_bpe {
                     "[0-9][0-9][0-9]",
                 };
                 break;
-            case LLAMA_VOCAB_PRE_TYPE_MPT:
-                // TODO: MPT pre-tokenization regexes are unknown
-                //       the following are close, but not exact. run the following:
-                //       ./bin/test-tokenizer-0 ../models/ggml-vocab-mpt.gguf
-                GGML_ASSERT("MPT pre-tokenization regexes are unknown - fixes needed");
-                regex_exprs = {
-                    "\\s?\\p{L}+",
-                    "\\s?\\p{P}+",
-                    "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)",
-                };
-                break;
             case LLAMA_VOCAB_PRE_TYPE_STARCODER:
             case LLAMA_VOCAB_PRE_TYPE_REFACT:
             case LLAMA_VOCAB_PRE_TYPE_COMMAND_R:
@@ -15431,6 +15421,7 @@ struct llm_tokenizer_bpe {
                 };
                 break;
             case LLAMA_VOCAB_PRE_TYPE_GPT2:
+            case LLAMA_VOCAB_PRE_TYPE_MPT:
             case LLAMA_VOCAB_PRE_TYPE_OLMO:
             case LLAMA_VOCAB_PRE_TYPE_JAIS:
                 regex_exprs = {
@@ -15457,8 +15448,8 @@ struct llm_tokenizer_bpe {
                 break;
             case LLAMA_VOCAB_PRE_TYPE_VIKING:
                 regex_exprs = {
-                    "\\p{N}",
                     " ?[^(\\s|.,!?…。,、।۔،)]+",
+                    "\\p{N}",
                 };
                 break;
             default:
@@ -16178,12 +16169,20 @@ struct fragment_buffer_variant {
 
 // #define PRETOKENIZERDEBUG
 
-static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list<fragment_buffer_variant> & buffer) {
+static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list<fragment_buffer_variant> & buffer, bool parse_special) {
     // for each special token
     for (const llama_vocab::id special_id : vocab.cache_special_tokens) {
         const auto & data = vocab.id_to_token[special_id];
         const auto & special_token = data.text;
 
+        if (!parse_special && (data.attr & (LLAMA_TOKEN_ATTR_CONTROL | LLAMA_TOKEN_ATTR_UNKNOWN))) {
+            // Ignore control and unknown tokens when parse_special == false
+            continue;
+            // User-defined tokens are still pre-tokenized before everything else
+            // ref: https://github.com/huggingface/tokenizers/blob/fdd26ba9a3f0c133427aab0423888cbde91362d7/tokenizers/src/tokenizer/mod.rs#L726
+            // This is mostly relevant for neox-style tokenizers (mpt, olmo, stablelm, etc.)
+        }
+
         // for each text fragment
         std::forward_list<fragment_buffer_variant>::iterator it = buffer.begin();
         while (it != buffer.end()) {
@@ -16296,7 +16295,7 @@ static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab &
 
     if (!raw_text.empty()) {
         fragment_buffer.emplace_front(raw_text, 0, raw_text.length());
-        if (parse_special) tokenizer_st_partition(vocab, fragment_buffer);
+        tokenizer_st_partition(vocab, fragment_buffer, parse_special);
     }
 
     switch (vocab.type) {
index 1f04b6f34ad7e488fec537a6a888b7d1dd86a460..d3d21331bfd3d1a177d88ac632b67d6d0f5edb08 100644 (file)
@@ -195,7 +195,7 @@ int main(int argc, char **argv) {
     const bool add_special = false;
 
     for (const auto & test_kv : k_tests) {
-        const std::vector<llama_token> res = llama_tokenize(ctx, test_kv.first, add_special, true);
+        const std::vector<llama_token> res = llama_tokenize(ctx, test_kv.first, add_special, false);
 
         printf("\n");
         printf("src: '%s'\n", test_kv.first.c_str());
@@ -253,7 +253,7 @@ int main(int argc, char **argv) {
         {
             const auto t_start = ggml_time_us();
 
-            res = llama_tokenize(ctx, text, add_special, true);
+            res = llama_tokenize(ctx, text, add_special, false);
 
             const auto t_end = ggml_time_us();
 
index c50a8ca32f6573f9e71aab27c70bfcd69684f304..9ebe6c89185a32b2767093f208d150015e04d14d 100644 (file)
@@ -20,7 +20,7 @@ from typing import Any, Iterator, cast
 from typing_extensions import Buffer
 
 import cffi
-from transformers import AutoTokenizer
+from transformers import AutoTokenizer, PreTrainedTokenizer
 
 
 logger = logging.getLogger("test-tokenizer-random")
@@ -129,7 +129,7 @@ class Tokenizer:
 class TokenizerGroundtruth (Tokenizer):
 
     def __init__(self, dir_tokenizer: str):
-        self.model = AutoTokenizer.from_pretrained(dir_tokenizer)
+        self.model: PreTrainedTokenizer = AutoTokenizer.from_pretrained(dir_tokenizer)
         # guess BOS and EOS
         ids = self.encode("a")
         assert 1 <= len(ids) <= 3
@@ -143,7 +143,7 @@ class TokenizerGroundtruth (Tokenizer):
         self.vocab = list(sorted(self.vocab))
         # tokens and lists
         self.special_tokens = list(self.model.all_special_tokens)
-        self.added_tokens   = list(self.model.added_tokens_encoder)
+        self.added_tokens   = self.model.batch_decode(self.model.added_tokens_encoder.values(), skip_special_tokens=False)
         self.bos_token = self.model.bos_token
         self.eos_token = self.model.eos_token
 
@@ -232,6 +232,7 @@ def generator_custom_text_edge_cases() -> Iterator[str]:
         'a\na',            # bert fail
         '"`',              # falcon
         ' \u2e4e',         # falcon
+        '\n\x0b  ',        # falcon
         'a\xa0\xa0\x00b',  # jina-v2-es
         'one <mask>',      # jina-v2-es  <mask> lstrip=true
         'a </s> b',        # rstrip phi-3