]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
vocab : refactor tokenizer to reduce init overhead (#9449)
authorZhenwei Jin <redacted>
Sat, 28 Sep 2024 12:10:58 +0000 (20:10 +0800)
committerGitHub <redacted>
Sat, 28 Sep 2024 12:10:58 +0000 (15:10 +0300)
* refactor tokenizer

* llama : make llm_tokenizer more private

ggml-ci

* refactor tokenizer

* refactor tokenizer

* llama : make llm_tokenizer more private

ggml-ci

* remove unused files

* remove unused fileds to avoid unused filed build error

* avoid symbol link error

* Update src/llama.cpp

* Update src/llama.cpp

---------

Co-authored-by: Georgi Gerganov <redacted>
examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp
src/llama-vocab.cpp
src/llama-vocab.h
src/llama.cpp
tests/test-tokenizer-0.cpp

index ecff95f9a69de438d5ad32954394d26391f3c172..c140daed3c056b15d464af514f23a02e3fef66cc 100644 (file)
@@ -201,7 +201,7 @@ static void print_sample_weights(TransformerWeights *w){
 
 //////////////////////////////////////// ggml structs and functions required to load models, configs and save the model.
 
-struct llama_vocab {
+struct my_llama_vocab {
     using id    = int32_t;
     using token = std::string;
     using ttype = llama_token_type;
@@ -525,7 +525,7 @@ static std::string llama_escape_whitespaces(const std::string & text) {
     return out.str();
 }
 
-static void load_vocab(const char * filename, const Config * config, struct llama_vocab * vocab) {
+static void load_vocab(const char * filename, const Config * config, struct my_llama_vocab * vocab) {
     if (is_ggml_file(filename)) {
         LOG_INF("%s: Loading vocabulary from gguf file %s\n", __func__, filename);
         struct ggml_context * ctx_data = NULL;
@@ -583,13 +583,13 @@ static void load_vocab(const char * filename, const Config * config, struct llam
         const int  n_vocab = config->vocab_size;
         /* uint32_t max_token_length =  */ file.read_u32(); // unused
         vocab->id_to_token.resize(n_vocab);
-        for (llama_vocab::id id=0; id<n_vocab; ++id) {
+        for (my_llama_vocab::id id=0; id<n_vocab; ++id) {
             float_t score = file.read_f32();
             uint32_t len = file.read_u32();
             std::string text = file.read_string(len);
 
             unsigned char byte_val;
-            llama_vocab::ttype type = LLAMA_TOKEN_TYPE_NORMAL;
+            my_llama_vocab::ttype type = LLAMA_TOKEN_TYPE_NORMAL;
             if (id == UNKNOWN_TOKEN_ID) {
                 text = "<unk>";
                 type = LLAMA_TOKEN_TYPE_UNKNOWN;
@@ -631,7 +631,7 @@ static void convert_weights_ak_to_gg(struct ggml_tensor * gg_weights, const floa
 }
 
 static void save_as_llama_model(
-    struct llama_vocab * vocab, struct my_llama_model * model, TransformerWeights* w, const char * filename
+    struct my_llama_vocab * vocab, struct my_llama_model * model, TransformerWeights* w, const char * filename
 ) {
     // convert AK weights into GG weights one by one.
     // w->token_embedding_table -> model->tok_embeddings
@@ -671,7 +671,7 @@ static void save_as_llama_model(
     std::vector<const char*> tokens;
     std::vector<float> scores;
     std::vector<llama_token_type> token_types;
-    for (const llama_vocab::token_data & token_data : vocab->id_to_token) {
+    for (const my_llama_vocab::token_data & token_data : vocab->id_to_token) {
         tokens.push_back(token_data.text.c_str());
         scores.push_back(token_data.score);
         token_types.push_back(token_data.type);
@@ -905,7 +905,7 @@ int main(int argc, char ** argv) {
         fclose(file);
     }
 
-    struct llama_vocab vocab;
+    struct my_llama_vocab vocab;
     load_vocab(params.fn_vocab_model, &config, &vocab);
 
     struct my_llama_model model;
index 146d416f770f275a0bdabac128c7d89521a030b3..e4d844a73c216c67db9ebb6b30e5b96f68c1440d 100644 (file)
@@ -50,7 +50,7 @@ struct naive_trie {
             res.first->second.insert(key + 1, len - 1, value);
         }
     }
-    std::pair<const char *, size_t> get_longest_prefix(const char * key, size_t len, size_t offset = 0) {
+    std::pair<const char *, size_t> get_longest_prefix(const char * key, size_t len, size_t offset = 0) const {
         if (len == 0 || offset == len) {
             return std::make_pair(key, offset);
         }
@@ -79,6 +79,15 @@ struct naive_trie {
 // impl
 //
 
+struct llm_tokenizer {
+   llm_tokenizer() {}
+   virtual ~llm_tokenizer() = default;
+};
+
+llama_vocab::~llama_vocab() {
+    delete tokenizer;
+}
+
 int llama_vocab::find_bpe_rank(const std::string & token_left, const std::string & token_right) const {
     GGML_ASSERT(token_left.find(' ')   == std::string::npos);
     GGML_ASSERT(token_left.find('\n')  == std::string::npos);
@@ -187,10 +196,15 @@ struct llm_bigram_spm {
     size_t size;
 };
 
-struct llm_tokenizer_spm {
-    llm_tokenizer_spm(const llama_vocab & vocab) : vocab(vocab) {}
+struct llm_tokenizer_spm : llm_tokenizer {
+    llm_tokenizer_spm(const llama_vocab & /*vocab*/) : llm_tokenizer() {}
+};
+
+struct llm_tokenizer_spm_session {
+    llm_tokenizer_spm_session(const llama_vocab & vocab) : vocab(vocab) {}
 
     void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) {
+
         // split string into utf8 chars
         int index = 0;
         size_t offs = 0;
@@ -271,7 +285,7 @@ private:
             return;
         }
 
-        resegment(symbols[p->second.first],  output);
+        resegment(symbols[p->second.first], output);
         resegment(symbols[p->second.second], output);
     }
 
@@ -279,7 +293,6 @@ private:
         if (left == -1 || right == -1) {
             return;
         }
-
         const std::string text = std::string(symbols[left].text, symbols[left].n + symbols[right].n);
         auto token = vocab.token_to_id.find(text);
 
@@ -306,10 +319,11 @@ private:
     }
 
     const llama_vocab & vocab;
+    // currently unused
+    // const llm_tokenizer_spm * spm_tokenizer;
 
     std::vector<llm_symbol> symbols;
     llm_bigram_spm::queue work_queue;
-
     std::map<std::string, std::pair<int, int>> rev_merge;
 };
 
@@ -352,8 +366,8 @@ struct llm_bigram_bpe {
     size_t size;
 };
 
-struct llm_tokenizer_bpe {
-    llm_tokenizer_bpe(const llama_vocab & vocab): vocab(vocab) {
+struct llm_tokenizer_bpe : llm_tokenizer {
+    llm_tokenizer_bpe(const llama_vocab & vocab) : llm_tokenizer() {
         GGML_ASSERT(vocab.type == LLAMA_VOCAB_TYPE_BPE);
         switch (vocab.type_pre) {
             case LLAMA_VOCAB_PRE_TYPE_LLAMA3:
@@ -476,7 +490,14 @@ struct llm_tokenizer_bpe {
         }
     }
 
-    void append(const llama_vocab::id token_id, std::vector<llama_vocab::id> & output) const {
+    std::vector<std::string> regex_exprs;
+};
+
+struct llm_tokenizer_bpe_session {
+    llm_tokenizer_bpe_session(const llama_vocab & vocab) : vocab(vocab),
+        bpe_tokenizer(static_cast<const llm_tokenizer_bpe *>(vocab.tokenizer)) {}
+
+    static void append(const llama_vocab::id token_id, std::vector<llama_vocab::id> & output)  {
         output.push_back(token_id);
     }
 
@@ -515,12 +536,11 @@ struct llm_tokenizer_bpe {
 
     void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) {
         int final_prev_index = -1;
-
-        const auto word_collection = unicode_regex_split(text, regex_exprs);
+        const auto word_collection = unicode_regex_split(text, bpe_tokenizer->regex_exprs);
 
         symbols_final.clear();
 
-        for (auto & word : word_collection) {
+        for (const auto & word : word_collection) {
             work_queue = llm_bigram_bpe::queue();
             symbols.clear();
 
@@ -623,7 +643,6 @@ private:
         if (left == -1 || right == -1) {
             return;
         }
-
         std::string left_token  = std::string(symbols[left].text,  symbols[left].n);
         std::string right_token = std::string(symbols[right].text, symbols[right].n);
 
@@ -647,12 +666,10 @@ private:
     }
 
     const llama_vocab & vocab;
-
-    std::vector<std::string> regex_exprs;
+    const llm_tokenizer_bpe * bpe_tokenizer;
 
     std::vector<llm_symbol> symbols;
     std::vector<llm_symbol> symbols_final;
-
     llm_bigram_bpe::queue work_queue;
 };
 
@@ -660,15 +677,17 @@ private:
 // WPM tokenizer
 //
 
-struct llm_tokenizer_wpm {
-    llm_tokenizer_wpm(const llama_vocab & vocab): vocab(vocab) {}
+struct llm_tokenizer_wpm : llm_tokenizer {
+    llm_tokenizer_wpm(const llama_vocab & /*vocab*/) : llm_tokenizer() {}
+};
 
-    void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) const {
-        const auto & token_map = vocab.token_to_id;
+struct llm_tokenizer_wpm_session {
+    llm_tokenizer_wpm_session(const llama_vocab & vocab) : vocab(vocab) {}
 
+    void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) {
+        const auto & token_map = vocab.token_to_id;
         // normalize and split by whitespace
         std::vector<std::string> words = preprocess(text);
-
         // bos token prepended already
 
         // find the longest tokens that form the words
@@ -713,7 +732,7 @@ struct llm_tokenizer_wpm {
     }
 
     // TODO: reduce string copies by using cpts_offs array
-    std::vector<std::string> preprocess(const std::string & text) const {
+    static std::vector<std::string> preprocess(const std::string & text)  {
         const std::vector<uint32_t> cpts_nfd = unicode_cpts_normalize_nfd(unicode_cpts_from_utf8(text));
         std::vector<std::string> words(1, "");
 
@@ -765,15 +784,18 @@ struct llm_tokenizer_wpm {
             //(cpt >= 0xFF00  && cpt <= 0xFFEF);
     }
 
+private:
     const llama_vocab & vocab;
+    // currently unused
+    // const llm_tokenizer_wpm * wpm_tokenizer;
 };
 
 //
 // UGM tokenizer
 //
 
-struct llm_tokenizer_ugm {
-    llm_tokenizer_ugm(const llama_vocab & vocab) : vocab(vocab) {
+struct llm_tokenizer_ugm : llm_tokenizer {
+    llm_tokenizer_ugm(const llama_vocab & vocab) : llm_tokenizer() {
         if (vocab.precompiled_charsmap.size() > 0) {
             size_t charsmap_offset = 0;
 
@@ -819,6 +841,30 @@ struct llm_tokenizer_ugm {
         unknown_token_score = min_score - unknown_token_score_penalty;
     }
 
+    // escaped space symbol - U+2581 (Lower One Eighth Block)
+    const std::string escaped_space = "\xE2\x96\x81";
+
+    const char * prefix_replacements = NULL;
+    size_t prefix_replacements_size = 0;
+
+    const uint32_t * xcda_array = NULL;
+    size_t xcda_array_size = 0;
+
+    struct naive_trie user_defined_token_matcher;
+
+    float min_score = FLT_MAX;
+    float max_score = -FLT_MAX;
+
+    float unknown_token_score_penalty = 10.0;
+    float unknown_token_score;
+
+    struct naive_trie token_matcher;
+};
+
+struct llm_tokenizer_ugm_session {
+    llm_tokenizer_ugm_session(const llama_vocab & vocab) : vocab(vocab),
+        ugm_tokenizer(static_cast<const llm_tokenizer_ugm *>(vocab.tokenizer)) {}
+
     /* This implementation is based on SentencePiece optimized Viterbi algorithm for
      * unigram language models. The general idea is to:
      * - move along the input sequence in steps of one UTF code point,
@@ -857,7 +903,7 @@ struct llm_tokenizer_ugm {
             // traverse the token matcher trie to find a matching token
             bool single_codepoint_token_found = false;
             const struct best_tokenization & current_best = tokenization_results[input_offset];
-            const struct naive_trie * node = token_matcher.traverse(normalized[prefix_offset++]);
+            const struct naive_trie * node = ugm_tokenizer->token_matcher.traverse(normalized[prefix_offset++]);
 
             while (prefix_offset <= input_len && node != NULL) {
                 // check if we found valid token in prefix
@@ -887,7 +933,7 @@ struct llm_tokenizer_ugm {
             // if we didn't find a valid token corresponding to the whole UTF code point
             // then use unknown token as the tokenization of this UTF code point
             if (!single_codepoint_token_found) {
-                const double challenger_score = current_best.score_sum + unknown_token_score;
+                const double challenger_score = current_best.score_sum + ugm_tokenizer->unknown_token_score;
                 prefix_offset = input_offset + n_utf8_code_units;
                 struct best_tokenization & current_champ = tokenization_results[prefix_offset];
                 if (challenger_score > current_champ.score_sum) {
@@ -919,7 +965,6 @@ struct llm_tokenizer_ugm {
     }
 
 private:
-    const llama_vocab & vocab;
 
     // helper structure for returning normalization results
     struct normalization_result {
@@ -932,7 +977,7 @@ private:
         normalized->clear();
         normalized->reserve(input.size() * 3);
 
-        const std::string space = vocab.tokenizer_escape_whitespaces ? escaped_space : " ";
+        const std::string space = vocab.tokenizer_escape_whitespaces ? ugm_tokenizer->escaped_space : " ";
 
         bool shall_prepend_space = !vocab.tokenizer_treat_whitespace_as_suffix && vocab.tokenizer_add_space_prefix;
         bool shall_append_space = vocab.tokenizer_treat_whitespace_as_suffix && vocab.tokenizer_add_space_prefix;
@@ -1014,13 +1059,21 @@ private:
         size_t xcda_array_size;
     };
 
+    // this structure stores the best tokenization so far at input_offset
+    struct best_tokenization {
+        llama_token token_id;
+        size_t input_offset;
+        float score_sum;
+    };
+
     struct normalization_result normalize_prefix(const std::string & input, size_t input_offset) {
         if (input_offset == input.size()) {
             return { &input[input_offset], 0, 0 };
         }
 
         // if input prefix matches some user-defined token return this token as normalization result
-        auto user_defined_token_match = user_defined_token_matcher.get_longest_prefix(&input[input_offset], input.size() - input_offset);
+        auto user_defined_token_match =
+           ugm_tokenizer->user_defined_token_matcher.get_longest_prefix(&input[input_offset], input.size() - input_offset);
         if (user_defined_token_match.second > 0) {
             return { &input[input_offset], user_defined_token_match.second, user_defined_token_match.second };
         }
@@ -1028,8 +1081,8 @@ private:
         size_t longest_prefix_length = 0;
         size_t longest_prefix_offset = 0;
 
-        if (xcda_array_size > 0) {
-            struct xcda_array_view xcda_view(xcda_array, xcda_array_size);
+        if (ugm_tokenizer->xcda_array_size > 0) {
+            struct xcda_array_view xcda_view(ugm_tokenizer->xcda_array, ugm_tokenizer->xcda_array_size);
 
             // Find the longest normalized sequence matching the input prefix by walking
             // the XOR-compressed compact double array (XCDA) starting from the root node
@@ -1065,50 +1118,27 @@ private:
 
         if (longest_prefix_length > 0) {
             // we have a match, so return the replacement sequence
-            if (longest_prefix_offset >= prefix_replacements_size) {
+            if (longest_prefix_offset >= ugm_tokenizer->prefix_replacements_size) {
                 throw std::runtime_error("Index out of array bounds in precompiled charsmap!");
             }
-            const char * prefix_replacement = &prefix_replacements[longest_prefix_offset];
+            const char * prefix_replacement = &(ugm_tokenizer->prefix_replacements)[longest_prefix_offset];
             return { prefix_replacement, strlen(prefix_replacement), longest_prefix_length };
-        } else {
-            // check if the input prefix contains a valid sequence of UTF-8 code units
-            try {
-                // if yes, return this sequence unmodified
-                size_t prefix_offset = input_offset;
-                unicode_cpt_from_utf8(input, prefix_offset);
-                return { &input[input_offset], prefix_offset - input_offset, prefix_offset - input_offset };
-            } catch (std::invalid_argument & /*ex*/) {
-                // if no, consume 1 byte and return U+FFFD - REPLACEMENT CHARACTER
-                return { "\xEF\xBF\xBD", 3, 1 };
-            }
         }
-    }
-
-    // escaped space symbol - U+2581 (Lower One Eighth Block)
-    const std::string escaped_space = "\xE2\x96\x81";
-
-    const char * prefix_replacements = NULL;
-    size_t prefix_replacements_size = 0;
-
-    const uint32_t * xcda_array = NULL;
-    size_t xcda_array_size = 0;
 
-    struct naive_trie user_defined_token_matcher;
-
-    // this structure stores the best tokenization so far at input_offset
-    struct best_tokenization {
-        llama_token token_id;
-        size_t input_offset;
-        float score_sum;
-    };
-
-    float min_score = FLT_MAX;
-    float max_score = -FLT_MAX;
-
-    float unknown_token_score_penalty = 10.0;
-    float unknown_token_score;
+        // check if the input prefix contains a valid sequence of UTF-8 code units
+        try {
+            // if yes, return this sequence unmodified
+            size_t prefix_offset = input_offset;
+            unicode_cpt_from_utf8(input, prefix_offset);
+            return { &input[input_offset], prefix_offset - input_offset, prefix_offset - input_offset };
+        } catch (std::invalid_argument & /*ex*/) {
+            // if no, consume 1 byte and return U+FFFD - REPLACEMENT CHARACTER
+            return { "\xEF\xBF\xBD", 3, 1 };
+        }
+    }
 
-    struct naive_trie token_matcher;
+    const llama_vocab & vocab;
+    const llm_tokenizer_ugm * ugm_tokenizer;
 };
 
 //
@@ -1169,8 +1199,8 @@ static std::vector<uint8_t> llama_unescape_rwkv_token(const std::string & escape
     return output;
 }
 
-struct llm_tokenizer_rwkv {
-    llm_tokenizer_rwkv(const llama_vocab & vocab): vocab(vocab) {
+struct llm_tokenizer_rwkv : llm_tokenizer {
+    llm_tokenizer_rwkv(const llama_vocab & vocab) : llm_tokenizer() {
         // RWKV supports arbitrary byte tokens, but the vocab struct only supports string tokens.
         // For now, we decode the vocab here into the lookup we'll use for tokenization.
 
@@ -1182,11 +1212,17 @@ struct llm_tokenizer_rwkv {
         }
     }
 
+    struct naive_trie token_matcher;
+};
+
+struct llm_tokenizer_rwkv_session {
+    llm_tokenizer_rwkv_session(const llama_vocab & vocab) : vocab(vocab),
+        rwkv_tokenizer(static_cast<const llm_tokenizer_rwkv &>(*vocab.tokenizer)) {}
+
     void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) {
         uint32_t position = 0;
-
         while (position < text.size()) {
-            const struct naive_trie * node = token_matcher.traverse(text[position]);
+            const struct naive_trie * node = rwkv_tokenizer.token_matcher.traverse(text[position]);
             if (node == NULL) {
                 // no matching token found, add unknown token
                 output.push_back(vocab.special_unk_id);
@@ -1211,11 +1247,33 @@ struct llm_tokenizer_rwkv {
         }
     }
 
+private:
     const llama_vocab & vocab;
-
-    struct naive_trie token_matcher;
+    const llm_tokenizer_rwkv & rwkv_tokenizer;
 };
 
+void llama_vocab::init_tokenizer() {
+    switch (type) {
+        case LLAMA_VOCAB_TYPE_SPM:
+            tokenizer = new llm_tokenizer_spm(*this);
+            break;
+        case LLAMA_VOCAB_TYPE_BPE:
+            tokenizer = new llm_tokenizer_bpe(*this);
+            break;
+        case LLAMA_VOCAB_TYPE_WPM:
+            tokenizer = new llm_tokenizer_wpm(*this);
+            break;
+        case LLAMA_VOCAB_TYPE_UGM:
+            tokenizer = new llm_tokenizer_ugm(*this);
+            break;
+        case LLAMA_VOCAB_TYPE_RWKV:
+            tokenizer = new llm_tokenizer_rwkv(*this);
+            break;
+        default:
+            GGML_ABORT("unsupported vocab type");
+    }
+}
+
 //
 // (de-) tokenize
 //
@@ -1277,7 +1335,7 @@ static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list<
 
             // if a fragment is text ( not yet processed )
             if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
-                auto & raw_text = fragment.raw_text;
+                const auto & raw_text = fragment.raw_text;
 
                 auto raw_text_base_offset = fragment.offset;
                 auto raw_text_base_length = fragment.length;
@@ -1376,7 +1434,13 @@ static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list<
     }
 }
 
-std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab & vocab, std::string raw_text, bool add_special, bool parse_special) {
+std::vector<llama_vocab::id> llama_tokenize_internal(
+        const llama_vocab & vocab,
+        std::string raw_text,
+        bool add_special,
+        bool parse_special) {
+    GGML_ASSERT(vocab.tokenizer && "Tokenizer not initialized. Call llama_vocab::init_tokenizer() first.");
+
     std::vector<llama_vocab::id> output;
     std::forward_list<fragment_buffer_variant> fragment_buffer;
 
@@ -1413,9 +1477,9 @@ std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab & vocab,
 #ifdef PRETOKENIZERDEBUG
                         LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str());
 #endif
-                        llm_tokenizer_spm tokenizer(vocab);
                         llama_escape_whitespace(raw_text);
-                        tokenizer.tokenize(raw_text, output);
+                        llm_tokenizer_spm_session session(vocab);
+                        session.tokenize(raw_text, output);
                         is_prev_special = false;
                     } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
                         output.push_back(fragment.token);
@@ -1437,10 +1501,11 @@ std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab & vocab,
             } break;
         case LLAMA_VOCAB_TYPE_BPE:
             {
-                llm_tokenizer_bpe tokenizer(vocab);
-
+                llm_tokenizer_bpe_session session(vocab);
+                // it calls some other methods that are not exist in llm_tokenizer,
+                // here just cast it to bpe tokenizer object
                 if (add_special) {
-                    tokenizer.append_bos(output);
+                    session.append_bos(output);
                 }
                 for (const auto & fragment : fragment_buffer) {
                     if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
@@ -1449,15 +1514,15 @@ std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab & vocab,
 #ifdef PRETOKENIZERDEBUG
                         LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str());
 #endif
-                        tokenizer.tokenize(raw_text, output);
+                        session.tokenize(raw_text, output);
                     } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
-                        tokenizer.append(fragment.token, output);
+                        session.append(fragment.token, output);
                     }
                 }
 
                 if (add_special) {
-                    tokenizer.append_eos(output);
-                    tokenizer.check_double_bos_eos(output);
+                    session.append_eos(output);
+                    session.check_double_bos_eos(output);
                 }
             } break;
         case LLAMA_VOCAB_TYPE_WPM:
@@ -1467,7 +1532,7 @@ std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab & vocab,
                     output.push_back(vocab.special_cls_id);
                 }
 
-                llm_tokenizer_wpm tokenizer(vocab);
+                llm_tokenizer_wpm_session session(vocab);
 
                 for (const auto & fragment : fragment_buffer) {
                     if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
@@ -1476,7 +1541,7 @@ std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab & vocab,
 #ifdef PRETOKENIZERDEBUG
                         LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str());
 #endif
-                        tokenizer.tokenize(raw_text, output);
+                        session.tokenize(raw_text, output);
                     } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
                         output.push_back(fragment.token);
                     }
@@ -1489,12 +1554,11 @@ std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab & vocab,
             } break;
         case LLAMA_VOCAB_TYPE_UGM:
             {
-                llm_tokenizer_ugm tokenizer(vocab);
-
                 if (add_special && vocab.tokenizer_add_bos != 0) {
                     GGML_ASSERT(vocab.special_bos_id != -1);
                     output.push_back(vocab.special_bos_id);
                 }
+                llm_tokenizer_ugm_session session(vocab);
 
                 for (const auto & fragment : fragment_buffer) {
                     if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
@@ -1502,7 +1566,7 @@ std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab & vocab,
 #ifdef PRETOKENIZERDEBUG
                         LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str());
 #endif
-                        tokenizer.tokenize(raw_text, output);
+                        session.tokenize(raw_text, output);
                     } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
                         output.push_back(fragment.token);
                     }
@@ -1522,6 +1586,7 @@ std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab & vocab,
             } break;
         case LLAMA_VOCAB_TYPE_RWKV:
             {
+                llm_tokenizer_rwkv_session session(vocab);
                 for (const auto & fragment : fragment_buffer) {
                     if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
                         auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length);
@@ -1530,8 +1595,7 @@ std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab & vocab,
                         LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str());
 #endif
 
-                        llm_tokenizer_rwkv tokenizer(vocab);
-                        tokenizer.tokenize(raw_text, output);
+                        session.tokenize(raw_text, output);
                     } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
                         output.push_back(fragment.token);
                     }
@@ -1644,13 +1708,13 @@ llama_token llama_token_eom_impl(const struct llama_vocab & vocab) {
 }
 
 int32_t llama_tokenize_impl(
-    const struct llama_vocab & vocab,
-                  const char * text,
-                     int32_t   text_len,
-                 llama_token * tokens,
-                     int32_t   n_tokens_max,
-                        bool   add_special,
-                        bool   parse_special) {
+        const struct llama_vocab & vocab,
+                      const char * text,
+                         int32_t   text_len,
+                     llama_token * tokens,
+                         int32_t   n_tokens_max,
+                            bool   add_special,
+                            bool   parse_special) {
     auto res = llama_tokenize_internal(vocab, std::string(text, text_len), add_special, parse_special);
     if (n_tokens_max < (int) res.size()) {
         // LLAMA_LOG_ERROR("%s: too many tokens\n", __func__);
@@ -1775,6 +1839,8 @@ int32_t llama_detokenize_impl(
                          int32_t   text_len_max,
                             bool   remove_special,
                             bool   unparse_special) {
+    GGML_ASSERT(vocab.tokenizer && "Tokenizer not initialized. Call llama_vocab::init_tokenizer() first.");
+
     int32_t avail = text_len_max;
     int32_t total = 0;
 
index cc46f642bf1ae371ce5fa2b23aa9ed3b44f5895d..069bdc423a60baedb7ca7d89a05ce71324982814 100644 (file)
@@ -8,6 +8,8 @@
 #include <map>
 #include <set>
 
+struct llm_tokenizer;
+
 struct llama_vocab {
     using id    = llama_token;
     using token = std::string;
@@ -65,7 +67,14 @@ struct llama_vocab {
 
     std::vector<char> precompiled_charsmap;
 
+    llm_tokenizer * tokenizer = nullptr;
+
+    llama_vocab() = default;
+    ~llama_vocab();
+
     int find_bpe_rank(const std::string & token_left, const std::string & token_right) const;
+
+    void init_tokenizer();
 };
 
 //
index f450eaf9ddc6fe9b47e7b6bcb8c2eaea9a6b05ab..44afb31d74e531e54621bd1dd53a0c4f722f6632 100644 (file)
@@ -6464,6 +6464,8 @@ static void llm_load_vocab(
     }
     GGML_ASSERT(vocab.id_to_token.size() == vocab.token_to_id.size());
 
+    vocab.init_tokenizer();
+
     // determine the newline token: LLaMA "<0x0A>" == 10 == '\n', Falcon 193 == '\n'
     if (vocab.type == LLAMA_VOCAB_TYPE_SPM) {
         // For Fill-In-the-Middle (FIM)/infill models which where converted
index d3d21331bfd3d1a177d88ac632b67d6d0f5edb08..4d49850c9ea25797488113b17d8fc41830cfa0ab 100644 (file)
@@ -7,6 +7,7 @@
 #include <map>
 #include <vector>
 #include <fstream>
+#include <thread>
 
 //static const std::map<std::string, std::vector<llama_token>> & k_tests() {
 //    static std::map<std::string, std::vector<llama_token>> _k_tests = {
@@ -194,45 +195,64 @@ int main(int argc, char **argv) {
 
     const bool add_special = false;
 
-    for (const auto & test_kv : k_tests) {
-        const std::vector<llama_token> res = llama_tokenize(ctx, test_kv.first, add_special, false);
-
-        printf("\n");
-        printf("src: '%s'\n", test_kv.first.c_str());
-        printf("res: '%s'\n", llama_detokenize(ctx, res).c_str());
-        printf("tok: ");
-        for (const auto & tok : res) {
-            printf("%d ", tok);
-        }
-        printf("\n");
-
-        bool correct = res.size() == test_kv.second.size();
-        for (int i = 0; i < (int) res.size() && correct; ++i) {
-            if (test_kv.second[i] != res[i]) {
-                correct = false;
+    // multi-threaded tokenization
+    const int nthread = std::thread::hardware_concurrency();
+    std::vector<std::thread> threads(nthread);
+
+    for (int i = 0; i < nthread; i++) {
+        threads[i] = std::thread([&, i]() {
+            for (const auto & test_kv : k_tests) {
+                const std::vector<llama_token> res = llama_tokenize(ctx, test_kv.first, add_special, false);
+
+                // here only print the result of the first thread
+                // because the other threads are running the same tests
+                if (i != 0) {
+                    continue;
+                }
+
+                printf("\n");
+                printf("src: '%s'\n", test_kv.first.c_str());
+                printf("res: '%s'\n", llama_detokenize(ctx, res).c_str());
+                printf("tok: ");
+                for (const auto & tok : res) {
+                    printf("%d ", tok);
+                }
+                printf("\n");
+
+                bool correct = res.size() == test_kv.second.size();
+                for (int i = 0; i < (int) res.size() && correct; ++i) {
+                    if (test_kv.second[i] != res[i]) {
+                        correct = false;
+                    }
+                }
+
+                if (!correct) {
+                    fprintf(stderr, "%s : failed test:    '%s'\n", __func__, test_kv.first.c_str());
+                    fprintf(stderr, "%s : detokenized to: '%s' instead of '%s'\n", __func__,
+                        llama_detokenize(ctx, res).c_str(),
+                        llama_detokenize(ctx, test_kv.second).c_str());
+                    fprintf(stderr, "%s : expected tokens: ", __func__);
+                    for (const auto & t : test_kv.second) {
+                        fprintf(stderr, "%6d '%s', ", t, llama_token_to_piece(ctx, t).c_str());
+                    }
+                    fprintf(stderr, "\n");
+                    fprintf(stderr, "%s : got tokens:      ", __func__);
+                    for (const auto & t : res) {
+                        fprintf(stderr, "%6d '%s', ", t, llama_token_to_piece(ctx, t).c_str());
+                    }
+                    fprintf(stderr, "\n");
+
+                    success = false;
+                }
             }
-        }
-
-        if (!correct) {
-            fprintf(stderr, "%s : failed test:    '%s'\n", __func__, test_kv.first.c_str());
-            fprintf(stderr, "%s : detokenized to: '%s' instead of '%s'\n", __func__,
-                llama_detokenize(ctx, res).c_str(),
-                llama_detokenize(ctx, test_kv.second).c_str());
-            fprintf(stderr, "%s : expected tokens: ", __func__);
-            for (const auto & t : test_kv.second) {
-                fprintf(stderr, "%6d '%s', ", t, llama_token_to_piece(ctx, t).c_str());
-            }
-            fprintf(stderr, "\n");
-            fprintf(stderr, "%s : got tokens:      ", __func__);
-            for (const auto & t : res) {
-                fprintf(stderr, "%6d '%s', ", t, llama_token_to_piece(ctx, t).c_str());
-            }
-            fprintf(stderr, "\n");
+        });
+    }
 
-            success = false;
-        }
+    for (int i = 0; i < nthread; i++) {
+        threads[i].join();
     }
 
+    // single threaded tokenization
     if (!fname_text.empty()) {
         fprintf(stderr, "%s : tokenizing: '%s'\n", __func__, fname_text.c_str());