]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
Tokenizer WPM fixes (#7500)
authorjaime-m-p <redacted>
Tue, 28 May 2024 19:46:34 +0000 (21:46 +0200)
committerGitHub <redacted>
Tue, 28 May 2024 19:46:34 +0000 (21:46 +0200)
* Update random test: add_bos_token.
* Update random test: add WPM models for testing.
* Build vocab.special_tokens_cache using vocab token types.
* Fix and improve WPM preprocessing.
  - Fix unicode edge case combinations.
  - Split by whitspace in the same pass.
* Discard all tokens when no matching found.

llama.cpp
tests/test-tokenizer-random.py

index 468a7cb25fa5002cf1cc3510196d5ee0e778d6f6..dac81acc06a92ffab6071f2a17a34d249c345d83 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
@@ -2162,7 +2162,7 @@ struct llama_vocab {
     std::unordered_map<token, id> token_to_id;
     std::vector<token_data>       id_to_token;
 
-    std::unordered_map<token, id> special_tokens_cache;
+    std::vector<id> special_tokens_cache;
 
     std::map<std::pair<std::string, std::string>, int> bpe_ranks;
 
@@ -4831,97 +4831,19 @@ static void llm_load_vocab(
 
     // build special tokens cache
     {
-        // TODO: It is unclear (to me) at this point, whether special tokes are guaranteed to be of a deterministic type,
-        //  and will always be correctly labeled in 'added_tokens.json' etc.
-        // The assumption is, since special tokens aren't meant to be exposed to end user, they are designed
-        //  to be unmatchable by the tokenizer, therefore tokens from the vocab, which are unmatchable by the tokenizer
-        //  are special tokens.
-        // From testing, this appears to correlate 1:1 with special tokens.
-        //
-
-        // Counting special tokens and verifying in only one direction
-        //  is sufficient to detect difference in those two sets.
-        //
-        uint32_t special_tokens_count_by_type = 0;
-        uint32_t special_tokens_count_from_verification = 0;
-
-        bool special_tokens_definition_mismatch = false;
-
-        for (const auto & t : vocab.token_to_id) {
-            const auto & token = t.first;
-            const auto & id    = t.second;
-
-            // Count all non-normal tokens in the vocab while iterating
+        for (llama_vocab::id id = 0; id < (llama_vocab::id)n_vocab; ++id) {
             if (vocab.id_to_token[id].type != LLAMA_TOKEN_TYPE_NORMAL) {
-                special_tokens_count_by_type++;
+                vocab.special_tokens_cache.push_back(id);
             }
+        }
 
-            // Skip single character tokens
-            if (token.length() > 1) {
-                bool is_tokenizable = false;
-
-                // Split token string representation in two, in all possible ways
-                //  and check if both halves can be matched to a valid token
-                for (unsigned i = 1; i < token.length();) {
-                    const auto left  = token.substr(0, i);
-                    const auto right = token.substr(i);
-
-                    // check if we didnt partition in the middle of a utf sequence
-                    auto utf = utf8_len(left.at(left.length() - 1));
-
-                    if (utf == 1) {
-                        if (vocab.token_to_id.find(left)  != vocab.token_to_id.end() &&
-                            vocab.token_to_id.find(right) != vocab.token_to_id.end() ) {
-                            is_tokenizable = true;
-                            break;
-                        }
-                        i++;
-                    } else {
-                        // skip over the rest of multibyte utf sequence
-                        i += utf - 1;
-                    }
-                }
-
-                if (!is_tokenizable) {
-                    // Some tokens are multibyte, but they are utf sequences with equivalent text length of 1
-                    //  it's faster to re-filter them here, since there are way less candidates now
-
-                    // Calculate a total "utf" length of a token string representation
-                    size_t utf8_str_len = 0;
-                    for (unsigned i = 0; i < token.length();) {
-                        utf8_str_len++;
-                        i += utf8_len(token.at(i));
-                    }
-
-                    // And skip the ones which are one character
-                    if (utf8_str_len > 1) {
-                        // At this point what we have left are special tokens only
-                        vocab.special_tokens_cache[token] = id;
-
-                        // Count manually found special tokens
-                        special_tokens_count_from_verification++;
-
-                        // If this manually found special token is not marked as such, flag a mismatch
-                        if (vocab.id_to_token[id].type == LLAMA_TOKEN_TYPE_NORMAL) {
-                            special_tokens_definition_mismatch = true;
-                        }
-                    }
-                }
+        std::sort( vocab.special_tokens_cache.begin(), vocab.special_tokens_cache.end(),
+            [&] (const llama_vocab::id a, const llama_vocab::id b) {
+                return vocab.id_to_token[a].text.size() > vocab.id_to_token[b].text.size();
             }
-        }
+        );
 
-        if (special_tokens_definition_mismatch || special_tokens_count_from_verification != special_tokens_count_by_type) {
-            LLAMA_LOG_WARN("%s: mismatch in special tokens definition ( %u/%zu vs %u/%zu ).\n",
-                __func__,
-                special_tokens_count_from_verification, vocab.id_to_token.size(),
-                special_tokens_count_by_type, vocab.id_to_token.size()
-            );
-        } else {
-            LLAMA_LOG_INFO("%s: special tokens definition check successful ( %u/%zu ).\n",
-                __func__,
-                special_tokens_count_from_verification, vocab.id_to_token.size()
-            );
-        }
+        LLAMA_LOG_INFO("%s: special tokens cache size = %u.\n", __func__, (uint32_t)vocab.special_tokens_cache.size());
     }
 }
 
@@ -13146,7 +13068,7 @@ struct llm_tokenizer_wpm {
     llm_tokenizer_wpm(const llama_vocab & vocab): vocab(vocab) {}
 
     void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) {
-        auto * token_map = &vocab.token_to_id;
+        const auto & token_map = vocab.token_to_id;
 
         // normalize and split by whitespace
         std::vector<std::string> words = preprocess(text);
@@ -13161,108 +13083,89 @@ struct llm_tokenizer_wpm {
             }
 
             // prepend phantom space
-            std::string word1 = "\xe2\x96\x81" + word;
-            int n = word1.size();
+            const std::string word1 = "\xe2\x96\x81" + word;
+            const int n = word1.size();
 
-            // we're at the start of a new word
-            int i = 0;
-            bool match_any = false;
+            const size_t current_tokens = output.size();
 
+            // we're at the start of a new word
             // move through character position in word
-            while (i < n) {
+            for (int i = 0; i < n; ++i) {
                 // loop through possible match length
                 bool match = false;
                 for (int j = n; j > i; j--) {
-                    auto it = token_map->find(word1.substr(i, j - i));
-                    if (it != token_map->end()) {
+                    auto it = token_map.find(word1.substr(i, j - i));
+                    if (it != token_map.end()) {
                         output.push_back(it->second);
                         match = true;
-                        match_any = true;
-                        i = j;
+                        i = j - 1;
                         break;
                     }
                 }
 
-                // must be an unknown character
-                if (!match) {
-                    i++;
+                if (!match) { // discard all
+                    output.resize(current_tokens);
+                    break;  // and discard next tokens
                 }
             }
 
             // we didn't find any matches for this word
-            if (!match_any) {
+            if (current_tokens == output.size()) {
                 output.push_back(vocab.special_unk_id);
             }
         }
     }
 
     std::vector<std::string> preprocess(const std::string & text) {
-        std::vector<uint32_t> cpts_nfd = unicode_cpts_normalize_nfd(unicode_cpts_from_utf8(text));
-
-        // strip accents, strip control, uniformize whitespace,
-        // to lowercase, pad chinese characters, pad punctuation
-        std::string new_str = "";
-        for (uint32_t code : cpts_nfd) {
-            const codepoint_flags flags = unicode_cpt_flags(code);
-            if (flags.is_accent_mark || flags.is_control) {
+        const std::vector<uint32_t> cpts_nfd = unicode_cpts_normalize_nfd(unicode_cpts_from_utf8(text));
+        std::vector<std::string> words(1, "");
+
+        for (const char32_t cpt : cpts_nfd) {
+            const auto flags = unicode_cpt_flags(cpt);
+
+            if (flags.is_whitespace) {
+                if (words.back().size()) {  // finish previous word if any
+                    words.emplace_back();
+                }
                 continue;
             }
-            code = unicode_tolower(code);
-            if (flags.is_separator || flags.is_whitespace) {  //####FIXME: is_separator ?
-                code = ' ';
-            }
-            std::string s = unicode_cpt_to_utf8(code);
-            if (flags.is_punctuation || is_ascii_punct(code) || is_chinese_char(code)) {
-                new_str += " ";
-                new_str += s;
-                new_str += " ";
-            } else {
-                new_str += s;
+
+            assert (!flags.is_separator);
+            if (cpt == 0 || cpt == 0xFFFD || flags.is_control) {
+                continue;
             }
-        }
 
-        // split by whitespace
-        uint64_t l = 0;
-        uint64_t r = 0;
-        std::vector<std::string> words;
-        while (r < new_str.size()) {
-            // if is whitespace
-            if (isspace(new_str[r], std::locale::classic())) {
-                if (r > l) words.push_back(new_str.substr(l, (r - l)));
-                l = r + 1;
-                r = l;
+            const std::string s = unicode_cpt_to_utf8(unicode_tolower(cpt));
+            if (flags.is_punctuation || ( cpt < 0x7F && flags.is_symbol ) || is_chinese_char(cpt)) {
+                if (words.back().size()) {  // finish previous word if any
+                    words.emplace_back();
+                }
+                words.back() = s;       // single char word
+                words.emplace_back();   // start a new word
             } else {
-                r += 1;
+                words.back() += s;  // append char to word
             }
         }
-        if (r > l) {
-            words.push_back(new_str.substr(l, (r - l)));
-        }
-        return words;
-    }
 
-    bool is_ascii_punct(uint32_t code) {
-        if (code > 0xFF) {
-            return false;
+        if (!words.back().size()) {
+            words.pop_back();
         }
-        auto c = char(static_cast<unsigned char>(code));
-        return ispunct(c, std::locale::classic());
+
+        return words;
     }
 
-    bool is_chinese_char(uint32_t cpt) {
-        if ((cpt >= 0x4E00  && cpt <= 0x9FFF)  ||
-            (cpt >= 0x3400  && cpt <= 0x4DBF)  ||
+    static bool is_chinese_char(uint32_t cpt) {
+        return
+            (cpt >= 0x04E00 && cpt <= 0x09FFF) ||
+            (cpt >= 0x03400 && cpt <= 0x04DBF) ||
             (cpt >= 0x20000 && cpt <= 0x2A6DF) ||
             (cpt >= 0x2A700 && cpt <= 0x2B73F) ||
             (cpt >= 0x2B740 && cpt <= 0x2B81F) ||
             (cpt >= 0x2B920 && cpt <= 0x2CEAF) || // this should be 0x2B820 but in hf rust code it is 0x2B920
-            (cpt >= 0xF900  && cpt <= 0xFAFF)  ||
-            (cpt >= 0x2F800 && cpt <= 0x2FA1F) ||
-            (cpt >= 0x3000  && cpt <= 0x303F)  ||
-            (cpt >= 0xFF00  && cpt <= 0xFFEF)) {
-            return true; // NOLINT
-        }
-        return false;
+            (cpt >= 0x0F900 && cpt <= 0x0FAFF) ||
+            (cpt >= 0x2F800 && cpt <= 0x2FA1F);
+            //(cpt >= 0x3000  && cpt <= 0x303F)  ||
+            //(cpt >= 0xFF00  && cpt <= 0xFFEF);
     }
 
     const llama_vocab & vocab;
@@ -13306,9 +13209,8 @@ struct fragment_buffer_variant {
 
 static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list<fragment_buffer_variant> & buffer) {
     // for each special token
-    for (const auto & st: vocab.special_tokens_cache) {
-        const auto & special_token = st.first;
-        const auto & special_id    = st.second;
+    for (const llama_vocab::id special_id : vocab.special_tokens_cache) {
+        const auto & special_token = vocab.id_to_token[special_id].text;
 
         // for each text fragment
         std::forward_list<fragment_buffer_variant>::iterator it = buffer.begin();
@@ -13317,7 +13219,7 @@ static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list<
 
             // if a fragment is text ( not yet processed )
             if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
-                auto * raw_text = &(fragment.raw_text);
+                auto & raw_text = fragment.raw_text;
 
                 auto raw_text_base_offset = fragment.offset;
                 auto raw_text_base_length = fragment.length;
@@ -13327,7 +13229,7 @@ static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list<
                     // find the first occurrence of a given special token in this fragment
                     //  passing offset argument only limit the "search area" but match coordinates
                     //  are still relative to the source full raw_text
-                    auto match = raw_text->find(special_token, raw_text_base_offset);
+                    auto match = raw_text.find(special_token, raw_text_base_offset);
 
                     // no occurrences found, stop processing this fragment for a given special token
                     if (match == std::string::npos) break;
@@ -13346,7 +13248,7 @@ static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list<
                         // left
                         const int64_t left_reminder_offset = raw_text_base_offset + 0;
                         const int64_t left_reminder_length = match - raw_text_base_offset;
-                        buffer.emplace_after(it, (*raw_text), left_reminder_offset, left_reminder_length);
+                        buffer.emplace_after(it, raw_text, left_reminder_offset, left_reminder_length);
 
 #ifdef PRETOKENIZERDEBUG
                         LLAMA_LOG_WARN("FL: (%ld %ld) '%s'\n", left_reminder_offset, left_reminder_length, raw_text->substr(left_reminder_offset, left_reminder_length).c_str());
@@ -13362,7 +13264,7 @@ static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list<
                     if (match + special_token.length() < raw_text_base_offset + raw_text_base_length) {
                         const int64_t right_reminder_offset = match + special_token.length();
                         const int64_t right_reminder_length = raw_text_base_length - ((match - raw_text_base_offset) + special_token.length());
-                        buffer.emplace_after(it, (*raw_text), right_reminder_offset, right_reminder_length);
+                        buffer.emplace_after(it, raw_text, right_reminder_offset, right_reminder_length);
 
 #ifdef PRETOKENIZERDEBUG
                         LLAMA_LOG_WARN("FR: (%ld %ld) '%s'\n", right_reminder_offset, right_reminder_length, raw_text->substr(right_reminder_offset, right_reminder_length).c_str());
index 7e1b656e5f5fc3834505ebae64d5cb4af70cfa94..ec1b2837cfab5c697e53cd6ec10628d00469f6a8 100644 (file)
@@ -167,8 +167,10 @@ def generator_random_special_tokens(tokenizer, iterations=100) -> Iterator[str]:
     for m in range(iterations):
         rand.seed(m)
         words = rand.choices(special_tokens, k=500)
-        if tokenizer.add_bos_token:  # skip spam warning of double BOS
-            while words and words[0] == tokenizer.bos_token:
+        if words[0] == tokenizer.bos_token:  # skip spam warning of double BOS
+            while len(words) > 1 and words[1] == tokenizer.bos_token:  # leave one starting BOS
+                words.pop(0)
+            if tokenizer.add_bos_token:  # drop all starting BOS
                 words.pop(0)
         yield "".join(words)
 
@@ -293,15 +295,17 @@ def main(argv: list[str] = None):
     model = LibLlamaModel(LibLlama(), args.vocab_file, mparams=dict(vocab_only=True), cparams=dict(n_ctx=4096))
     tokenizer = AutoTokenizer.from_pretrained(args.dir_tokenizer)
 
-    tokenizer.add_bos_token = getattr(tokenizer, "add_bos_token", True)
-    tokenizer.add_eos_token = getattr(tokenizer, "add_eos_token", False)
-
     def func_tokenize1(text: str):
         return model.tokenize(text, add_special=True, parse_special=True)
 
     def func_tokenize2(text: str):
         return tokenizer.encode(text, add_special_tokens=True)
 
+    ids = func_tokenize2("a")
+    assert 1 <= len(ids) <= 3
+    add_bos_token = len(ids) > 1 and tokenizer.bos_token_id == ids[0]
+    tokenizer.add_bos_token = getattr(tokenizer, "add_bos_token", add_bos_token)
+
     vocab = list(sorted(tokenizer.batch_decode(list(tokenizer.get_vocab().values()), skip_special_tokens=True)))
     test_compare_tokenizer(func_tokenize1, func_tokenize2, generator_custom_text())
     test_compare_tokenizer(func_tokenize1, func_tokenize2, generator_custom_text_edge_cases())
@@ -324,8 +328,10 @@ if __name__ == "__main__":
     # import os
     # tokenizers = os.listdir(path_tokenizers)
     tokenizers = [
-        "llama-spm",   # SPM
-        "phi-3",       # SPM
+        # "llama-spm",   # SPM
+        # "phi-3",       # SPM
+        "jina-v2-en",  # WPM
+        "bert-bge",    # WPM
     ]
 
     for tokenizer in tokenizers: