std::unordered_map<token, id> token_to_id;
std::vector<token_data> id_to_token;
- std::unordered_map<token, id> special_tokens_cache;
+ std::vector<id> special_tokens_cache;
std::map<std::pair<std::string, std::string>, int> bpe_ranks;
// build special tokens cache
{
- // TODO: It is unclear (to me) at this point, whether special tokes are guaranteed to be of a deterministic type,
- // and will always be correctly labeled in 'added_tokens.json' etc.
- // The assumption is, since special tokens aren't meant to be exposed to end user, they are designed
- // to be unmatchable by the tokenizer, therefore tokens from the vocab, which are unmatchable by the tokenizer
- // are special tokens.
- // From testing, this appears to correlate 1:1 with special tokens.
- //
-
- // Counting special tokens and verifying in only one direction
- // is sufficient to detect difference in those two sets.
- //
- uint32_t special_tokens_count_by_type = 0;
- uint32_t special_tokens_count_from_verification = 0;
-
- bool special_tokens_definition_mismatch = false;
-
- for (const auto & t : vocab.token_to_id) {
- const auto & token = t.first;
- const auto & id = t.second;
-
- // Count all non-normal tokens in the vocab while iterating
+ for (llama_vocab::id id = 0; id < (llama_vocab::id)n_vocab; ++id) {
if (vocab.id_to_token[id].type != LLAMA_TOKEN_TYPE_NORMAL) {
- special_tokens_count_by_type++;
+ vocab.special_tokens_cache.push_back(id);
}
+ }
- // Skip single character tokens
- if (token.length() > 1) {
- bool is_tokenizable = false;
-
- // Split token string representation in two, in all possible ways
- // and check if both halves can be matched to a valid token
- for (unsigned i = 1; i < token.length();) {
- const auto left = token.substr(0, i);
- const auto right = token.substr(i);
-
- // check if we didnt partition in the middle of a utf sequence
- auto utf = utf8_len(left.at(left.length() - 1));
-
- if (utf == 1) {
- if (vocab.token_to_id.find(left) != vocab.token_to_id.end() &&
- vocab.token_to_id.find(right) != vocab.token_to_id.end() ) {
- is_tokenizable = true;
- break;
- }
- i++;
- } else {
- // skip over the rest of multibyte utf sequence
- i += utf - 1;
- }
- }
-
- if (!is_tokenizable) {
- // Some tokens are multibyte, but they are utf sequences with equivalent text length of 1
- // it's faster to re-filter them here, since there are way less candidates now
-
- // Calculate a total "utf" length of a token string representation
- size_t utf8_str_len = 0;
- for (unsigned i = 0; i < token.length();) {
- utf8_str_len++;
- i += utf8_len(token.at(i));
- }
-
- // And skip the ones which are one character
- if (utf8_str_len > 1) {
- // At this point what we have left are special tokens only
- vocab.special_tokens_cache[token] = id;
-
- // Count manually found special tokens
- special_tokens_count_from_verification++;
-
- // If this manually found special token is not marked as such, flag a mismatch
- if (vocab.id_to_token[id].type == LLAMA_TOKEN_TYPE_NORMAL) {
- special_tokens_definition_mismatch = true;
- }
- }
- }
+ std::sort( vocab.special_tokens_cache.begin(), vocab.special_tokens_cache.end(),
+ [&] (const llama_vocab::id a, const llama_vocab::id b) {
+ return vocab.id_to_token[a].text.size() > vocab.id_to_token[b].text.size();
}
- }
+ );
- if (special_tokens_definition_mismatch || special_tokens_count_from_verification != special_tokens_count_by_type) {
- LLAMA_LOG_WARN("%s: mismatch in special tokens definition ( %u/%zu vs %u/%zu ).\n",
- __func__,
- special_tokens_count_from_verification, vocab.id_to_token.size(),
- special_tokens_count_by_type, vocab.id_to_token.size()
- );
- } else {
- LLAMA_LOG_INFO("%s: special tokens definition check successful ( %u/%zu ).\n",
- __func__,
- special_tokens_count_from_verification, vocab.id_to_token.size()
- );
- }
+ LLAMA_LOG_INFO("%s: special tokens cache size = %u.\n", __func__, (uint32_t)vocab.special_tokens_cache.size());
}
}
llm_tokenizer_wpm(const llama_vocab & vocab): vocab(vocab) {}
void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) {
- auto * token_map = &vocab.token_to_id;
+ const auto & token_map = vocab.token_to_id;
// normalize and split by whitespace
std::vector<std::string> words = preprocess(text);
}
// prepend phantom space
- std::string word1 = "\xe2\x96\x81" + word;
- int n = word1.size();
+ const std::string word1 = "\xe2\x96\x81" + word;
+ const int n = word1.size();
- // we're at the start of a new word
- int i = 0;
- bool match_any = false;
+ const size_t current_tokens = output.size();
+ // we're at the start of a new word
// move through character position in word
- while (i < n) {
+ for (int i = 0; i < n; ++i) {
// loop through possible match length
bool match = false;
for (int j = n; j > i; j--) {
- auto it = token_map->find(word1.substr(i, j - i));
- if (it != token_map->end()) {
+ auto it = token_map.find(word1.substr(i, j - i));
+ if (it != token_map.end()) {
output.push_back(it->second);
match = true;
- match_any = true;
- i = j;
+ i = j - 1;
break;
}
}
- // must be an unknown character
- if (!match) {
- i++;
+ if (!match) { // discard all
+ output.resize(current_tokens);
+ break; // and discard next tokens
}
}
// we didn't find any matches for this word
- if (!match_any) {
+ if (current_tokens == output.size()) {
output.push_back(vocab.special_unk_id);
}
}
}
std::vector<std::string> preprocess(const std::string & text) {
- std::vector<uint32_t> cpts_nfd = unicode_cpts_normalize_nfd(unicode_cpts_from_utf8(text));
-
- // strip accents, strip control, uniformize whitespace,
- // to lowercase, pad chinese characters, pad punctuation
- std::string new_str = "";
- for (uint32_t code : cpts_nfd) {
- const codepoint_flags flags = unicode_cpt_flags(code);
- if (flags.is_accent_mark || flags.is_control) {
+ const std::vector<uint32_t> cpts_nfd = unicode_cpts_normalize_nfd(unicode_cpts_from_utf8(text));
+ std::vector<std::string> words(1, "");
+
+ for (const char32_t cpt : cpts_nfd) {
+ const auto flags = unicode_cpt_flags(cpt);
+
+ if (flags.is_whitespace) {
+ if (words.back().size()) { // finish previous word if any
+ words.emplace_back();
+ }
continue;
}
- code = unicode_tolower(code);
- if (flags.is_separator || flags.is_whitespace) { //####FIXME: is_separator ?
- code = ' ';
- }
- std::string s = unicode_cpt_to_utf8(code);
- if (flags.is_punctuation || is_ascii_punct(code) || is_chinese_char(code)) {
- new_str += " ";
- new_str += s;
- new_str += " ";
- } else {
- new_str += s;
+
+ assert (!flags.is_separator);
+ if (cpt == 0 || cpt == 0xFFFD || flags.is_control) {
+ continue;
}
- }
- // split by whitespace
- uint64_t l = 0;
- uint64_t r = 0;
- std::vector<std::string> words;
- while (r < new_str.size()) {
- // if is whitespace
- if (isspace(new_str[r], std::locale::classic())) {
- if (r > l) words.push_back(new_str.substr(l, (r - l)));
- l = r + 1;
- r = l;
+ const std::string s = unicode_cpt_to_utf8(unicode_tolower(cpt));
+ if (flags.is_punctuation || ( cpt < 0x7F && flags.is_symbol ) || is_chinese_char(cpt)) {
+ if (words.back().size()) { // finish previous word if any
+ words.emplace_back();
+ }
+ words.back() = s; // single char word
+ words.emplace_back(); // start a new word
} else {
- r += 1;
+ words.back() += s; // append char to word
}
}
- if (r > l) {
- words.push_back(new_str.substr(l, (r - l)));
- }
- return words;
- }
- bool is_ascii_punct(uint32_t code) {
- if (code > 0xFF) {
- return false;
+ if (!words.back().size()) {
+ words.pop_back();
}
- auto c = char(static_cast<unsigned char>(code));
- return ispunct(c, std::locale::classic());
+
+ return words;
}
- bool is_chinese_char(uint32_t cpt) {
- if ((cpt >= 0x4E00 && cpt <= 0x9FFF) ||
- (cpt >= 0x3400 && cpt <= 0x4DBF) ||
+ static bool is_chinese_char(uint32_t cpt) {
+ return
+ (cpt >= 0x04E00 && cpt <= 0x09FFF) ||
+ (cpt >= 0x03400 && cpt <= 0x04DBF) ||
(cpt >= 0x20000 && cpt <= 0x2A6DF) ||
(cpt >= 0x2A700 && cpt <= 0x2B73F) ||
(cpt >= 0x2B740 && cpt <= 0x2B81F) ||
(cpt >= 0x2B920 && cpt <= 0x2CEAF) || // this should be 0x2B820 but in hf rust code it is 0x2B920
- (cpt >= 0xF900 && cpt <= 0xFAFF) ||
- (cpt >= 0x2F800 && cpt <= 0x2FA1F) ||
- (cpt >= 0x3000 && cpt <= 0x303F) ||
- (cpt >= 0xFF00 && cpt <= 0xFFEF)) {
- return true; // NOLINT
- }
- return false;
+ (cpt >= 0x0F900 && cpt <= 0x0FAFF) ||
+ (cpt >= 0x2F800 && cpt <= 0x2FA1F);
+ //(cpt >= 0x3000 && cpt <= 0x303F) ||
+ //(cpt >= 0xFF00 && cpt <= 0xFFEF);
}
const llama_vocab & vocab;
static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list<fragment_buffer_variant> & buffer) {
// for each special token
- for (const auto & st: vocab.special_tokens_cache) {
- const auto & special_token = st.first;
- const auto & special_id = st.second;
+ for (const llama_vocab::id special_id : vocab.special_tokens_cache) {
+ const auto & special_token = vocab.id_to_token[special_id].text;
// for each text fragment
std::forward_list<fragment_buffer_variant>::iterator it = buffer.begin();
// if a fragment is text ( not yet processed )
if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
- auto * raw_text = &(fragment.raw_text);
+ auto & raw_text = fragment.raw_text;
auto raw_text_base_offset = fragment.offset;
auto raw_text_base_length = fragment.length;
// find the first occurrence of a given special token in this fragment
// passing offset argument only limit the "search area" but match coordinates
// are still relative to the source full raw_text
- auto match = raw_text->find(special_token, raw_text_base_offset);
+ auto match = raw_text.find(special_token, raw_text_base_offset);
// no occurrences found, stop processing this fragment for a given special token
if (match == std::string::npos) break;
// left
const int64_t left_reminder_offset = raw_text_base_offset + 0;
const int64_t left_reminder_length = match - raw_text_base_offset;
- buffer.emplace_after(it, (*raw_text), left_reminder_offset, left_reminder_length);
+ buffer.emplace_after(it, raw_text, left_reminder_offset, left_reminder_length);
#ifdef PRETOKENIZERDEBUG
LLAMA_LOG_WARN("FL: (%ld %ld) '%s'\n", left_reminder_offset, left_reminder_length, raw_text->substr(left_reminder_offset, left_reminder_length).c_str());
if (match + special_token.length() < raw_text_base_offset + raw_text_base_length) {
const int64_t right_reminder_offset = match + special_token.length();
const int64_t right_reminder_length = raw_text_base_length - ((match - raw_text_base_offset) + special_token.length());
- buffer.emplace_after(it, (*raw_text), right_reminder_offset, right_reminder_length);
+ buffer.emplace_after(it, raw_text, right_reminder_offset, right_reminder_length);
#ifdef PRETOKENIZERDEBUG
LLAMA_LOG_WARN("FR: (%ld %ld) '%s'\n", right_reminder_offset, right_reminder_length, raw_text->substr(right_reminder_offset, right_reminder_length).c_str());