]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
llama : optimize long word tokenization with WPM (#8034)
authorGeorgi Gerganov <redacted>
Fri, 21 Jun 2024 05:51:28 +0000 (08:51 +0300)
committerGitHub <redacted>
Fri, 21 Jun 2024 05:51:28 +0000 (08:51 +0300)
ggml-ci

llama.cpp
unicode.cpp

index 9ca0b7479619d6b22a6b4764aefa2de09846ee74..a05a52b4234cd0a6afd48a654c37de3497d92d9d 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
@@ -2293,6 +2293,8 @@ struct llama_vocab {
     enum llama_vocab_type     type     = LLAMA_VOCAB_TYPE_SPM;
     enum llama_vocab_pre_type type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
 
+    int max_token_len = 0; // used for optimizing longest token search
+
     std::unordered_map<token, id> token_to_id;
     std::vector<token_data>       id_to_token;
 
@@ -4939,6 +4941,7 @@ static void llm_load_vocab(
         GGML_ASSERT(unicode_cpts_from_utf8(word).size() > 0);
 
         vocab.token_to_id[word] = i;
+        vocab.max_token_len = std::max(vocab.max_token_len, (int) word.size());
 
         auto & token_data = vocab.id_to_token[i];
         token_data.text  = std::move(word);
@@ -5249,6 +5252,8 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
     if (vocab.special_middle_id != -1) { LLAMA_LOG_INFO( "%s: MID token        = %d '%s'\n", __func__, vocab.special_middle_id, vocab.id_to_token[vocab.special_middle_id].text.c_str() ); }
     if (vocab.special_eot_id    != -1) { LLAMA_LOG_INFO( "%s: EOT token        = %d '%s'\n", __func__, vocab.special_eot_id,    vocab.id_to_token[vocab.special_eot_id].text.c_str() );    }
 
+    LLAMA_LOG_INFO("%s: max token length = %d\n", __func__, vocab.max_token_len);
+
     if (model.arch == LLM_ARCH_DEEPSEEK2) {
         LLAMA_LOG_INFO("%s: n_layer_dense_lead   = %d\n",     __func__, hparams.n_layer_dense_lead);
         LLAMA_LOG_INFO("%s: n_lora_q             = %d\n",     __func__, hparams.n_lora_q);
@@ -13488,7 +13493,7 @@ private:
 struct llm_tokenizer_wpm {
     llm_tokenizer_wpm(const llama_vocab & vocab): vocab(vocab) {}
 
-    void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) {
+    void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) const {
         const auto & token_map = vocab.token_to_id;
 
         // normalize and split by whitespace
@@ -13497,7 +13502,7 @@ struct llm_tokenizer_wpm {
         // bos token prepended already
 
         // find the longest tokens that form the words
-        for (const std::string &word : words) {
+        for (const std::string & word : words) {
             // skip empty words
             if (word.size() == 0) {
                 continue;
@@ -13514,7 +13519,7 @@ struct llm_tokenizer_wpm {
             for (int i = 0; i < n; ++i) {
                 // loop through possible match length
                 bool match = false;
-                for (int j = n; j > i; j--) {
+                for (int j = std::min(n, i + vocab.max_token_len + 1); j > i; j--) {
                     auto it = token_map.find(word1.substr(i, j - i));
                     if (it != token_map.end()) {
                         output.push_back(it->second);
@@ -13537,7 +13542,8 @@ struct llm_tokenizer_wpm {
         }
     }
 
-    std::vector<std::string> preprocess(const std::string & text) {
+    // TODO: reduce string copies by using cpts_offs array
+    std::vector<std::string> preprocess(const std::string & text) const {
         const std::vector<uint32_t> cpts_nfd = unicode_cpts_normalize_nfd(unicode_cpts_from_utf8(text));
         std::vector<std::string> words(1, "");
 
@@ -13832,6 +13838,8 @@ static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab &
                     output.push_back(vocab.special_cls_id);
                 }
 
+                llm_tokenizer_wpm tokenizer(vocab);
+
                 for (const auto & fragment : fragment_buffer) {
                     if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
                         auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length);
@@ -13839,7 +13847,6 @@ static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab &
 #ifdef PRETOKENIZERDEBUG
                         LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str());
 #endif
-                        llm_tokenizer_wpm tokenizer(vocab);
                         tokenizer.tokenize(raw_text, output);
                     } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
                         output.push_back(fragment.token);
index 913c34b9b7bd6ca81804519d396ce5dc84f87509..c0b76bf20aedef4a9d3385d39e7ecb7ed84351a5 100644 (file)
@@ -596,6 +596,7 @@ std::vector<uint32_t> unicode_cpts_normalize_nfd(const std::vector<uint32_t> & c
 
 std::vector<uint32_t> unicode_cpts_from_utf8(const std::string & utf8) {
     std::vector<uint32_t> result;
+    result.reserve(utf8.size());
     size_t offset = 0;
     while (offset < utf8.size()) {
         result.push_back(unicode_cpt_from_utf8(utf8, offset));