]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
Detokenizer fixes (#8039)
authorjaime-m-p <redacted>
Fri, 5 Jul 2024 17:01:35 +0000 (19:01 +0200)
committerGitHub <redacted>
Fri, 5 Jul 2024 17:01:35 +0000 (19:01 +0200)
* Add llama_detokenize():
  - Update header files location
  - UNKNOWN and CONTROL are 'special pieces'
  - Remove space after UNKNOWN and CONTROL
  - Refactor llama_token_to_piece()
  - Add flag: clean_up_tokenization_spaces
  - Symmetric params for llama_tokenize() and llama_detokenize()

* Update and fix tokenizer tests:
  - Using llama_detokenize()
  - Unexpected vocab type as test fail instead of error
    - Useful when automating tests:
    - If you don't know in advance the vocab type
    - Differenciate other loading errors
  - Skip unicode surrogaes and undefined
  - Gracefully exit threads
    - Using exit() is throwing random exceptions
  - Clean old known problematic codepoints
  - Minor: confusing hexadecimal codepoint

* Update bruteforce random tests
  - Add detokenizer checks
  - New generator: ascii_lr_strip
  - New generator: apostrophe
  - Add more vocabs files
  - Detokenize special tokens.
  - Replace errors with '\uFFFD' when detokenizing to 'utf-8'
  - More edge cases
  - Better detokenization results check

* Fix add_space_prefix, set false by default
* Better leading space removal
* Do not remove space when decoding special tokens
* Bugfix: custom regexs splits undefined unicode codepoints
* 'viking' detokenizer clean spaces

common/common.cpp
common/common.h
examples/batched.swift/Sources/main.swift
examples/llama.swiftui/llama.cpp.swift/LibLlama.swift
include/llama.h
src/llama.cpp
src/unicode.cpp
tests/test-tokenizer-0.cpp
tests/test-tokenizer-1-bpe.cpp
tests/test-tokenizer-1-spm.cpp
tests/test-tokenizer-random.py

index 42ae85f5f11c214ab56d090caa28e5026e9fc0ed..c548bcb2857a8ad1e127c279059834a69ae5b3ce 100644 (file)
@@ -2592,51 +2592,35 @@ std::vector<llama_token> llama_tokenize(
 }
 
 std::string llama_token_to_piece(const struct llama_context * ctx, llama_token token, bool special) {
-    std::vector<char> result(8, 0);
-    const int n_tokens = llama_token_to_piece(llama_get_model(ctx), token, result.data(), result.size(), special);
-    if (n_tokens < 0) {
-        result.resize(-n_tokens);
-        int check = llama_token_to_piece(llama_get_model(ctx), token, result.data(), result.size(), special);
-        GGML_ASSERT(check == -n_tokens);
-    } else {
-        result.resize(n_tokens);
-    }
-
-    return std::string(result.data(), result.size());
-}
-
-std::string llama_detokenize_spm(llama_context * ctx, const std::vector<llama_token> & tokens) {
-    const llama_token bos_id = llama_token_bos(llama_get_model(ctx));
-
     std::string piece;
-    std::string result;
-
-    for (size_t i = 0; i < tokens.size(); ++i) {
-        piece = llama_token_to_piece(ctx, tokens[i]);
-
-        // remove the leading space of the first non-BOS token
-        if (((tokens[0] == bos_id && i == 1) || (tokens[0] != bos_id && i == 0)) && piece[0] == ' ') {
-            piece = piece.substr(1);
-        }
-
-        result += piece;
+    piece.resize(piece.capacity());  // using string internal cache, 15 bytes + '\n'
+    const int n_chars = llama_token_to_piece(llama_get_model(ctx), token, &piece[0], piece.size(), 0, special);
+    if (n_chars < 0) {
+        piece.resize(-n_chars);
+        int check = llama_token_to_piece(llama_get_model(ctx), token, &piece[0], piece.size(), 0, special);
+        GGML_ASSERT(check == -n_chars);
+    }
+    else {
+        piece.resize(n_chars);
     }
 
-    return result;
+    return piece;
 }
 
-std::string llama_detokenize_bpe(llama_context * ctx, const std::vector<llama_token> & tokens) {
-    std::string piece;
-    std::string result;
-
-    for (size_t i = 0; i < tokens.size(); ++i) {
-        piece = llama_token_to_piece(ctx, tokens[i]);
-
-        result += piece;
+std::string llama_detokenize(llama_context * ctx, const std::vector<llama_token> & tokens, bool special) {
+    std::string text;
+    text.resize(std::max(text.capacity(), tokens.size()));
+    int32_t n_chars = llama_detokenize(llama_get_model(ctx), tokens.data(), (int32_t)tokens.size(), &text[0], (int32_t)text.size(), false, special);
+    if (n_chars < 0) {
+        text.resize(-n_chars);
+        n_chars = llama_detokenize(llama_get_model(ctx), tokens.data(), (int32_t)tokens.size(), &text[0], (int32_t)text.size(), false, special);
+        GGML_ASSERT(n_chars <= (int32_t)text.size());  // whitespace trimming is performed after per-token detokenization
     }
 
+    text.resize(n_chars);
+
     // NOTE: the original tokenizer decodes bytes after collecting the pieces.
-    return result;
+    return text;
 }
 
 bool llama_should_add_bos_token(const llama_model * model) {
index ac86483ffa1c946d6e5b43bd9f64726df3accf67..dabf02b598b3b0519ad55165d548e38c896cb10b 100644 (file)
@@ -350,21 +350,13 @@ std::string llama_token_to_piece(
                        llama_token   token,
                        bool          special = true);
 
-// TODO: these should be moved in llama.h C-style API under single `llama_detokenize` function
-//       that takes into account the tokenizer type and decides how to handle the leading space
-//
-// detokenizes a vector of tokens into a string
-// should work similar to Python's `tokenizer.decode`
-// removes the leading space from the first non-BOS token
-std::string llama_detokenize_spm(
-                         llama_context * ctx,
-        const std::vector<llama_token> & tokens);
-
 // detokenizes a vector of tokens into a string
 // should work similar to Python's `tokenizer.decode`
-std::string llama_detokenize_bpe(
+// optionally renders special/control tokens
+std::string llama_detokenize(
                          llama_context * ctx,
-        const std::vector<llama_token> & tokens);
+        const std::vector<llama_token> & tokens,
+                                  bool   special = true);
 
 // Uses the value from the model metadata if possible, otherwise
 // defaults to true when model type is SPM, otherwise false.
index dbbd06da5818364178a2fcc2c986100a5c2a2f72..616494d2d841d099af6332f34765121eff81a951 100644 (file)
@@ -229,7 +229,7 @@ private func tokenize(text: String, add_bos: Bool) -> [llama_token] {
 
 private func token_to_piece(token: llama_token, buffer: inout [CChar]) -> String? {
     var result = [CChar](repeating: 0, count: 8)
-    let nTokens = llama_token_to_piece(model, token, &result, Int32(result.count), false)
+    let nTokens = llama_token_to_piece(model, token, &result, Int32(result.count), 0, false)
     if nTokens < 0 {
         let actualTokensCount = -Int(nTokens)
         result = .init(repeating: 0, count: actualTokensCount)
@@ -238,6 +238,7 @@ private func token_to_piece(token: llama_token, buffer: inout [CChar]) -> String
             token,
             &result,
             Int32(result.count),
+            0,
             false
         )
         assert(check == actualTokensCount)
index 737f882fb2d2eeb2b7c6ea3b3639e0f8cd0f2f0d..2a3f9f75890fc75783d31f1b8ae3fc0c5aa50aeb 100644 (file)
@@ -322,7 +322,7 @@ actor LlamaContext {
         defer {
             result.deallocate()
         }
-        let nTokens = llama_token_to_piece(model, token, result, 8, false)
+        let nTokens = llama_token_to_piece(model, token, result, 8, 0, false)
 
         if nTokens < 0 {
             let newResult = UnsafeMutablePointer<Int8>.allocate(capacity: Int(-nTokens))
@@ -330,7 +330,7 @@ actor LlamaContext {
             defer {
                 newResult.deallocate()
             }
-            let nNewTokens = llama_token_to_piece(model, token, newResult, -nTokens, false)
+            let nNewTokens = llama_token_to_piece(model, token, newResult, -nTokens, 0, false)
             let bufferPointer = UnsafeBufferPointer(start: newResult, count: Int(nNewTokens))
             return Array(bufferPointer)
         } else {
index 7a9a25609c9c1949794f79bc772a533153da26c2..865ace9944d02064a518bfdba4ebb794c127a52f 100644 (file)
@@ -904,6 +904,7 @@ extern "C" {
     /// @param tokens The tokens pointer must be large enough to hold the resulting tokens.
     /// @return Returns the number of tokens on success, no more than n_tokens_max
     /// @return Returns a negative number on failure - the number of tokens that would have been returned
+    /// @param add_special Allow to add BOS and EOS tokens if model is configured to do so.
     /// @param parse_special Allow tokenizing special and/or control tokens which otherwise are not exposed and treated
     ///                      as plaintext. Does not insert a leading space.
     LLAMA_API int32_t llama_tokenize(
@@ -918,15 +919,31 @@ extern "C" {
     // Token Id -> Piece.
     // Uses the vocabulary in the provided context.
     // Does not write null terminator to the buffer.
-    // User code is responsible to remove the leading whitespace of the first non-BOS token when decoding multiple tokens.
+    // User can skip up to 'lstrip' leading spaces before copying (useful when encoding/decoding multiple tokens with 'add_space_prefix')
     // @param special If true, special tokens are rendered in the output.
     LLAMA_API int32_t llama_token_to_piece(
               const struct llama_model * model,
                            llama_token   token,
                                   char * buf,
                                int32_t   length,
+                               int32_t   lstrip,
                                   bool   special);
 
+    /// @details Convert the provided tokens into text (inverse of llama_tokenize()).
+    /// @param text The char pointer must be large enough to hold the resulting text.
+    /// @return Returns the number of chars/bytes on success, no more than text_len_max.
+    /// @return Returns a negative number on failure - the number of chars/bytes that would have been returned.
+    /// @param remove_special Allow to remove BOS and EOS tokens if model is configured to do so.
+    /// @param unparse_special If true, special tokens are rendered in the output.
+    LLAMA_API int32_t llama_detokenize(
+        const struct llama_model * model,
+               const llama_token * tokens,
+                         int32_t   n_tokens,
+                            char * text,
+                         int32_t   text_len_max,
+                            bool   remove_special,
+                            bool   unparse_special);
+
     /// Apply chat template. Inspired by hf apply_chat_template() on python.
     /// Both "model" and "custom_template" are optional, but at least one is required. "custom_template" has higher precedence than "model"
     /// NOTE: This function does not use a jinja parser. It only support a pre-defined list of template. See more: https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template
index 18956d441409fb47ccb6df5f950eb28488b9e787..b770ca5bc33fc6621b653c49e07e35274e24cb1e 100644 (file)
@@ -1995,18 +1995,19 @@ using llama_mlocks = std::vector<std::unique_ptr<llama_mlock>>;
 
 // NOTE: avoid ever using this except for building the token_to_piece caches
 static std::string llama_token_to_piece(const struct llama_model * model, llama_token token, bool special) {
-    std::vector<char> result(8, 0);
-    const int n_tokens = llama_token_to_piece(model, token, result.data(), result.size(), special);
-    if (n_tokens < 0) {
-        result.resize(-n_tokens);
-        int check = llama_token_to_piece(model, token, result.data(), result.size(), special);
-        GGML_ASSERT(check == -n_tokens);
+    std::string piece;
+    piece.resize(piece.capacity());  // using string internal cache
+    const int n_chars = llama_token_to_piece(model, token, &piece[0], piece.size(), 0, special);
+    if (n_chars < 0) {
+        piece.resize(-n_chars);
+        int check = llama_token_to_piece(model, token, &piece[0], piece.size(), 0, special);
+        GGML_ASSERT(check == -n_chars);
     }
     else {
-        result.resize(n_tokens);
+        piece.resize(n_chars);
     }
 
-    return std::string(result.data(), result.size());
+    return piece;
 }
 
 static ggml_backend_buffer_type_t llama_default_buffer_type_cpu(bool host_buffer) {
@@ -2586,10 +2587,11 @@ struct llama_vocab {
     id special_eot_id    = -1; // TODO: move above after "eos_id", and here add "file separator" token
 
     // tokenizer flags
-    bool tokenizer_add_space_prefix = true;
+    bool tokenizer_add_space_prefix = false;
     bool tokenizer_add_bos          = false;
     bool tokenizer_add_eos          = false;
     bool tokenizer_ignore_merges    = false;
+    bool tokenizer_clean_spaces     = false;  // clean_up_tokenization_spaces
     bool tokenizer_remove_extra_whitespaces   = false;
     bool tokenizer_escape_whitespaces         = true;
     bool tokenizer_treat_whitespace_as_suffix = false;
@@ -5230,11 +5232,6 @@ static void llm_load_vocab(
             vocab.special_pad_id  = -1;
             vocab.special_cls_id  = -1;
             vocab.special_mask_id = -1;
-
-            const int add_space_prefix_keyidx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_ADD_PREFIX).c_str());
-            if (add_space_prefix_keyidx != -1) {
-                vocab.tokenizer_add_space_prefix = gguf_get_val_bool(ctx, add_space_prefix_keyidx);
-            } // The default value of add_space_prefix is true.
         } else if (tokenizer_model == "bert") {
             vocab.type = LLAMA_VOCAB_TYPE_WPM;
 
@@ -5246,15 +5243,9 @@ static void llm_load_vocab(
             vocab.special_pad_id  = 0;
             vocab.special_cls_id  = 101;
             vocab.special_mask_id = 103;
-            vocab.tokenizer_add_space_prefix = false;
         } else if (tokenizer_model == "gpt2") {
             vocab.type = LLAMA_VOCAB_TYPE_BPE;
 
-            const int add_space_prefix_keyidx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_ADD_PREFIX).c_str());
-            if (add_space_prefix_keyidx != -1) {
-                vocab.tokenizer_add_space_prefix = gguf_get_val_bool(ctx, add_space_prefix_keyidx);
-            }
-
             // read bpe merges and populate bpe ranks
             const int merges_keyidx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_MERGES).c_str());
             if (merges_keyidx == -1) {
@@ -5333,6 +5324,8 @@ static void llm_load_vocab(
 
         // for now, only BPE models have pre-tokenizers
         if (vocab.type == LLAMA_VOCAB_TYPE_BPE) {
+            vocab.tokenizer_add_space_prefix = false;
+            vocab.tokenizer_clean_spaces = true;
             if (tokenizer_pre.empty()) {
                 LLAMA_LOG_WARN("%s: missing pre-tokenizer type, using: 'default'\n", __func__);
                 LLAMA_LOG_WARN("%s:                                             \n", __func__);
@@ -5354,9 +5347,11 @@ static void llm_load_vocab(
             } else if (
                     tokenizer_pre == "deepseek-llm") {
                 vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_LLM;
+                vocab.tokenizer_clean_spaces = false;
             } else if (
                     tokenizer_pre == "deepseek-coder") {
                 vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_CODER;
+                vocab.tokenizer_clean_spaces = false;
             } else if (
                     tokenizer_pre == "falcon") {
                 vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_FALCON;
@@ -5368,6 +5363,7 @@ static void llm_load_vocab(
                 vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_STARCODER;
             } else if (
                     tokenizer_pre == "gpt-2"   ||
+                    tokenizer_pre == "phi-2"   ||
                     tokenizer_pre == "jina-es" ||
                     tokenizer_pre == "jina-de" ||
                     tokenizer_pre == "jina-v2-es" ||
@@ -5383,6 +5379,7 @@ static void llm_load_vocab(
             } else if (
                 tokenizer_pre == "qwen2") {
                 vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_QWEN2;
+                vocab.tokenizer_clean_spaces = false;
             } else if (
                 tokenizer_pre == "stablelm2") {
                 vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_STABLELM2;
@@ -5398,9 +5395,11 @@ static void llm_load_vocab(
             } else if (
                 tokenizer_pre == "poro-chat") {
                 vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_PORO;
+                vocab.tokenizer_clean_spaces = false;
             } else if (
                 tokenizer_pre == "viking") {
                 vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_VIKING;
+                vocab.tokenizer_clean_spaces = false;
             } else if (
                 tokenizer_pre == "jais") {
                 vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_JAIS;
@@ -5409,10 +5408,14 @@ static void llm_load_vocab(
             }
         } else if (vocab.type == LLAMA_VOCAB_TYPE_SPM) {
             vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
+            vocab.tokenizer_add_space_prefix = true;
+            vocab.tokenizer_clean_spaces = false;
             vocab.tokenizer_add_bos = true;
             vocab.tokenizer_add_eos = false;
         } else if (vocab.type == LLAMA_VOCAB_TYPE_WPM) {
             vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
+            vocab.tokenizer_add_space_prefix = false;
+            vocab.tokenizer_clean_spaces = true;
             vocab.tokenizer_add_bos = true;
             vocab.tokenizer_add_eos = false;
         } else if (vocab.type == LLAMA_VOCAB_TYPE_UGM) {
@@ -5422,6 +5425,11 @@ static void llm_load_vocab(
         } else {
             vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
         }
+
+        const int add_space_prefix_keyidx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_ADD_PREFIX).c_str());
+        if (add_space_prefix_keyidx != -1) {
+            vocab.tokenizer_add_space_prefix = gguf_get_val_bool(ctx, add_space_prefix_keyidx);
+        }
     }
 
     const int token_idx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_LIST).c_str());
@@ -5603,7 +5611,7 @@ static void llm_load_vocab(
             }
         }
 
-        std::sort( vocab.cache_special_tokens.begin(), vocab.cache_special_tokens.end(),
+        std::sort(vocab.cache_special_tokens.begin(), vocab.cache_special_tokens.end(),
             [&] (const llama_vocab::id a, const llama_vocab::id b) {
                 return vocab.id_to_token[a].text.size() > vocab.id_to_token[b].text.size();
             }
@@ -16098,7 +16106,7 @@ static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab &
                 // tokenizer.encode('', add_special_tokens=True)  returns [1]
                 // tokenizer.encode('', add_special_tokens=False) returns []
 
-                bool is_prev_special = false;
+                bool is_prev_special = true;  // prefix with space if first token
 
                 if (add_special && vocab.tokenizer_add_bos) {
                     GGML_ASSERT(vocab.special_bos_id != -1);
@@ -16110,10 +16118,9 @@ static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab &
                     if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
                         auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length);
 
-                        if (vocab.tokenizer_add_space_prefix) {
-                            if (!output.size() || is_prev_special) {  // prefix with space if first token
-                                raw_text = " " + raw_text;
-                            }
+                        // prefix with space if previous is special
+                        if (vocab.tokenizer_add_space_prefix && is_prev_special) {
+                            raw_text = " " + raw_text;
                         }
 
 #ifdef PRETOKENIZERDEBUG
@@ -16122,6 +16129,7 @@ static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab &
                         llm_tokenizer_spm tokenizer(vocab);
                         llama_escape_whitespace(raw_text);
                         tokenizer.tokenize(raw_text, output);
+                        is_prev_special = false;
                     } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
                         output.push_back(fragment.token);
                         is_prev_special = true;
@@ -20904,85 +20912,66 @@ static std::string llama_decode_text(const std::string & text) {
 }
 
 // does not write null-terminator to buf
-int32_t llama_token_to_piece(const struct llama_model * model, llama_token token, char * buf, int32_t length, bool special) {
+int32_t llama_token_to_piece(const struct llama_model * model, llama_token token, char * buf, int32_t length, int32_t lstrip, bool special) {
     // ref: https://github.com/ggerganov/llama.cpp/pull/7587#discussion_r1620983843
-    if (!special && llama_is_control_token(model->vocab, token)) {
+    static const int attr_special = LLAMA_TOKEN_ATTR_UNKNOWN | LLAMA_TOKEN_ATTR_CONTROL;
+    const llama_token_attr attr = llama_token_get_attr(model, token);
+    if (!special && (attr & attr_special)) {
         return 0;
     }
 
+    // copy piece chars to output text buffer
+    // skip up to 'lstrip' leading spaces before copying
+    auto _try_copy = [=] (const char * token, size_t size) -> int32_t {
+        for (int32_t i = 0; i < lstrip && size && *token == ' '; ++i) {
+            token++;
+            size--;
+        }
+        if (length < (int32_t)size) {
+            return (int32_t) -size;
+        }
+        memcpy(buf, token, size);
+        return (int32_t) size;
+    };
+
     // if we have a cache - use it
     {
         const auto & cache = model->vocab.cache_token_to_piece;
 
         if (!cache.empty()) {
-            const auto & res = cache.at(token);
-            if (length < (int) res.size()) {
-                return -(int) res.size();
-            }
-            memcpy(buf, res.c_str(), res.size());
-            return res.size();
+            const auto & result = cache.at(token);
+            return _try_copy(result.data(), result.size());
         }
     }
 
     if (0 <= token && token < llama_n_vocab(model)) {
+        const std::string & token_text = model->vocab.id_to_token[token].text;
         switch (llama_vocab_get_type(model->vocab)) {
             case LLAMA_VOCAB_TYPE_WPM:
             case LLAMA_VOCAB_TYPE_SPM:
             case LLAMA_VOCAB_TYPE_UGM: {
                 // NOTE: we accept all unsupported token types,
                 // suppressing them like CONTROL tokens.
-                if (llama_is_normal_token(model->vocab, token)) {
-                    std::string result = model->vocab.id_to_token[token].text;
+                if (attr & (attr_special | LLAMA_TOKEN_ATTR_USER_DEFINED)) {
+                    return _try_copy(token_text.data(), token_text.size());
+                } else if (attr & LLAMA_TOKEN_ATTR_NORMAL) {
+                    std::string result = token_text;
                     llama_unescape_whitespace(result);
-                    if (length < (int) result.length()) {
-                        return -(int) result.length();
-                    }
-                    memcpy(buf, result.c_str(), result.length());
-                    return result.length();
-                } else if (
-                        (llama_is_user_defined_token(model->vocab, token)) ||
-                        (llama_is_control_token     (model->vocab, token) && special)) {
-                    std::string result = model->vocab.id_to_token[token].text;
-                    if (length < (int) result.length()) {
-                        return -(int) result.length();
-                    }
-                    memcpy(buf, result.c_str(), result.length());
-                    return result.length();
-                } else if (llama_is_unknown_token(model->vocab, token)) { // NOLINT
-                    if (length < 3) {
-                        return -3;
-                    }
-                    memcpy(buf, "\xe2\x96\x85", 3);
-                    return 3;
-                } else if (llama_is_byte_token(model->vocab, token)) {
-                    if (length < 1) {
-                        return -1;
-                    }
-                    buf[0] = llama_token_to_byte(model->vocab, token);
-                    return 1;
+                    return _try_copy(result.data(), result.size());
+                } else if (attr & LLAMA_TOKEN_ATTR_BYTE) {
+                    char byte = (char) llama_token_to_byte(model->vocab, token);
+                    return _try_copy((char*) &byte, 1);
                 }
                 break;
             }
             case LLAMA_VOCAB_TYPE_BPE: {
                 // NOTE: we accept all unsupported token types,
                 // suppressing them like CONTROL tokens.
-                if (llama_is_normal_token(model->vocab, token)) {
-                    std::string result = model->vocab.id_to_token[token].text;
-                    result = llama_decode_text(result);
-                    if (length < (int) result.length()) {
-                        return -(int) result.length();
-                    }
-                    memcpy(buf, result.c_str(), result.length());
-                    return result.length();
-                } else if (
-                        (llama_is_user_defined_token(model->vocab, token)) ||
-                        (llama_is_control_token     (model->vocab, token) && special)) {
-                    std::string result = model->vocab.id_to_token[token].text;
-                    if (length < (int) result.length()) {
-                        return -(int) result.length();
-                    }
-                    memcpy(buf, result.c_str(), result.length());
-                    return result.length();
+                if (attr & (attr_special | LLAMA_TOKEN_ATTR_USER_DEFINED)) {
+                    return _try_copy(token_text.data(), token_text.size());
+                } else if (attr & LLAMA_TOKEN_ATTR_NORMAL) {
+                    std::string result = llama_decode_text(token_text);
+                    return _try_copy(result.data(), result.size());
                 }
                 break;
             }
@@ -20993,6 +20982,113 @@ int32_t llama_token_to_piece(const struct llama_model * model, llama_token token
     return 0;
 }
 
+int32_t llama_detokenize(
+        const struct llama_model * model,
+               const llama_token * tokens,
+                         int32_t   n_tokens,
+                            char * text,
+                         int32_t   text_len_max,
+                            bool   remove_special,
+                            bool   unparse_special) {
+    int32_t avail = text_len_max;
+    int32_t total = 0;
+
+    // remove the leading space
+    bool remove_space = model->vocab.tokenizer_add_space_prefix;
+
+    if (remove_special && model->vocab.tokenizer_add_bos) {
+        if (n_tokens > 0 && tokens[0] == model->vocab.special_bos_id) {
+            remove_space = false;
+            n_tokens--;
+            tokens++;
+        }
+    }
+
+    if (remove_special && model->vocab.tokenizer_add_eos) {
+        if (n_tokens > 0 && tokens[n_tokens-1] == model->vocab.special_eos_id) {
+            n_tokens--;
+        }
+    }
+
+    for (int32_t i = 0; i < n_tokens; ++i) {
+        GGML_ASSERT(avail >= 0);
+        int32_t n_chars = llama_token_to_piece(model, tokens[i], text, avail, remove_space, unparse_special);
+        remove_space = false;
+        if (n_chars < 0) {
+            avail = 0;
+            total -= n_chars;
+        } else if (n_chars > 0) {
+            avail -= n_chars;
+            text  += n_chars;
+            total += n_chars;
+        }
+    }
+
+    if (total > text_len_max) {
+        return -total;
+    }
+
+    if (model->vocab.tokenizer_clean_spaces) {
+        text -= total;  // restart text
+
+        // first pass: characters ?!.,  //TODO: where do these characters come from?
+        const int32_t total1 = total;
+        total = total ? 1 : 0;
+        for (int32_t i = 1; i < total1; ++i) {
+            const char x = text[i];
+            if (text[i - 1] == ' ') {
+                if (x == '?' || x == '!' || x == '.' || x == ',') {  // " ?", " !", " .", " ,"
+                    total--;  // remove space
+                }
+            }
+            text[total++] = x;
+        }
+
+        // second pass: strip single apostrophe between spaces
+        const int32_t total2 = total;
+        total = total ? 1 : 0;
+        for (int32_t i = 1; i < total2; ++i) {
+            const char x = text[i];
+            if (x == '\'' && i + 1 < total2 && text[i - 1] == ' ' && text[i + 1] == ' ') {  // " ' "
+                total--;           // remove prev space
+                text[++i] = '\0';  // remove next space
+            }
+            text[total++] = x;
+        }
+
+        // third pass: apostrophe contractions  //NOTE: this makes sense?
+        const int32_t total3 = total;
+        total = total ? 1 : 0;
+        for (int32_t i = 1; i < total3; ++i) {
+            const char x = text[i];
+            if (text[i - 1] == ' ') {
+                if (x == '\'' && i + 1 < total3) {
+                    const char x1 = text[i + 1];
+                    if (x1 == 't' || x1 == 'd') {  // " 't", " 'd"
+                        //total--;  // remove space
+                    } else if (x1 == 's' || x1 == 'm') {  // " 's", " 'm"
+                        total--;  // remove space
+                    } else if (i + 2 < total3) {
+                        const char x2 = text[i + 2];
+                        if ((x1 == 'l' && x2 == 'l')) {  // " 'll"
+                            //total--;  // remove space
+                        } else if ((x1 == 'r' && x2 == 'e') || (x1 == 'v' && x2 == 'e')) {  // " 're", " 've"
+                            total--;  // remove space
+                        } else {
+                            //total--;  // remove space
+                        }
+                    } else {
+                        //total--;  // remove space
+                    }
+                }
+            }
+            text[total++] = x;
+        }
+    }
+
+    return total <= text_len_max ? total : -total;
+}
+
 // trim whitespace from the beginning and end of a string
 static std::string trim(const std::string & str) {
     size_t start = 0;
index 8692924b957ccf903c24aa26bb99bcd2f2593766..51daa15afa66927d5cf984474b941896fc9edd9f 100644 (file)
@@ -232,8 +232,7 @@ static std::vector<size_t> unicode_regex_split_custom_gpt2(const std::string & t
         };
 
         auto _get_flags = [&] (const size_t pos) -> codepoint_flags {
-            static const codepoint_flags undef(codepoint_flags::UNDEFINED);
-            return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_flags(cpts[pos]) : undef;
+            return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_flags(cpts[pos]) : codepoint_flags{};
         };
 
         size_t _prev_end = offset_ini;
@@ -295,9 +294,9 @@ static std::vector<size_t> unicode_regex_split_custom_gpt2(const std::string & t
                 continue;
             }
             // regex: <space>?[^\s\p{L}\p{N}]+
-            if (!(flags2.is_whitespace || flags2.is_letter || flags2.is_number || flags2.is_undefined)) {
+            if (!(flags2.is_whitespace | flags2.is_letter | flags2.is_number) && flags2.as_uint()) {
                 pos += (cpt == ' ');
-                while (!(flags2.is_whitespace || flags2.is_letter || flags2.is_number || flags2.is_undefined)) {
+                while (!(flags2.is_whitespace | flags2.is_letter | flags2.is_number) && flags2.as_uint()) {
                     flags2 = _get_flags(++pos);
                 }
                 _add_token(pos);
@@ -351,8 +350,7 @@ static std::vector<size_t> unicode_regex_split_custom_llama3(const std::string &
         };
 
         auto _get_flags = [&] (const size_t pos) -> codepoint_flags {
-            static const codepoint_flags undef(codepoint_flags::UNDEFINED);
-            return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_flags(cpts[pos]) : undef;
+            return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_flags(cpts[pos]) : codepoint_flags{};
         };
 
         size_t _prev_end = offset_ini;
@@ -394,8 +392,8 @@ static std::vector<size_t> unicode_regex_split_custom_llama3(const std::string &
                 }
             }
 
-            // regex: [^\r\n\p{L}\p{N}]?\p{L}+  //####FIXME: the first \p{L} is correct?
-            if (!(cpt == '\r' || cpt == '\n' || /*flags.is_letter |*/ flags.is_number)) {
+            // regex: [^\r\n\p{L}\p{N}]?\p{L}+
+            if (!(cpt == '\r' || cpt == '\n' || flags.is_number)) {
                 if (flags.is_letter || _get_flags(pos+1).is_letter) {  // one or more letters
                     pos++;
                     while (_get_flags(pos).is_letter) {
@@ -421,9 +419,9 @@ static std::vector<size_t> unicode_regex_split_custom_llama3(const std::string &
 
             // regex: <space>?[^\s\p{L}\p{N}]+[\r\n]*
             auto flags2 = (cpt == ' ' ? _get_flags(pos+1) : flags);
-            if (!(flags2.is_whitespace || flags2.is_letter || flags2.is_number || flags2.is_undefined)) {
+            if (!(flags2.is_whitespace | flags2.is_letter | flags2.is_number) && flags.as_uint()) {
                 pos += (cpt == ' ');
-                while (!(flags2.is_whitespace || flags2.is_letter || flags2.is_number || flags2.is_undefined)) {
+                while (!(flags2.is_whitespace | flags2.is_letter | flags2.is_number) && flags2.as_uint()) {
                     flags2 = _get_flags(++pos);
                 }
                 uint32_t cpt2 = _get_cpt(pos);
index d478f104148a6cebc4ef3a8d67a7b978dcf07ac7..1f04b6f34ad7e488fec537a6a888b7d1dd86a460 100644 (file)
@@ -195,11 +195,11 @@ int main(int argc, char **argv) {
     const bool add_special = false;
 
     for (const auto & test_kv : k_tests) {
-        const std::vector<llama_token> res = llama_tokenize(ctx, test_kv.first, add_special);
+        const std::vector<llama_token> res = llama_tokenize(ctx, test_kv.first, add_special, true);
 
         printf("\n");
         printf("src: '%s'\n", test_kv.first.c_str());
-        printf("res: '%s'\n", llama_detokenize_bpe(ctx, res).c_str());
+        printf("res: '%s'\n", llama_detokenize(ctx, res).c_str());
         printf("tok: ");
         for (const auto & tok : res) {
             printf("%d ", tok);
@@ -216,8 +216,8 @@ int main(int argc, char **argv) {
         if (!correct) {
             fprintf(stderr, "%s : failed test:    '%s'\n", __func__, test_kv.first.c_str());
             fprintf(stderr, "%s : detokenized to: '%s' instead of '%s'\n", __func__,
-                llama_detokenize_bpe(ctx, res).c_str(),
-                llama_detokenize_bpe(ctx, test_kv.second).c_str());
+                llama_detokenize(ctx, res).c_str(),
+                llama_detokenize(ctx, test_kv.second).c_str());
             fprintf(stderr, "%s : expected tokens: ", __func__);
             for (const auto & t : test_kv.second) {
                 fprintf(stderr, "%6d '%s', ", t, llama_token_to_piece(ctx, t).c_str());
@@ -253,7 +253,7 @@ int main(int argc, char **argv) {
         {
             const auto t_start = ggml_time_us();
 
-            res = llama_tokenize(ctx, text, add_special);
+            res = llama_tokenize(ctx, text, add_special, true);
 
             const auto t_end = ggml_time_us();
 
@@ -272,7 +272,7 @@ int main(int argc, char **argv) {
             }
 
             for (const auto & tok : res) {
-                //ofs << tok << " '" << string_strip(llama_detokenize_bpe(ctx, std::vector<int>{tok})) << "'" << std::endl;
+                //ofs << tok << " '" << string_strip(llama_detokenize(ctx, std::vector<int>{tok})) << "'" << std::endl;
                 ofs << tok << "\n";
             }
         }
index 209a04ad6f77ad1840d4ad5d87d5ade8a5705b48..9498387e0f21210bacd6fdb6e1cde6c825419b0a 100644 (file)
@@ -11,6 +11,7 @@
 #include <string>
 #include <thread>
 #include <vector>
+#include <atomic>
 
 int main(int argc, char **argv) {
     if (argc < 2 || argc > 3) {
@@ -63,7 +64,10 @@ int main(int argc, char **argv) {
         }
     }
 
-    GGML_ASSERT(llama_vocab_type(model) == LLAMA_VOCAB_TYPE_BPE);
+    //GGML_ASSERT(llama_vocab_type(model) == LLAMA_VOCAB_TYPE_BPE);
+    if (llama_vocab_type(model) != LLAMA_VOCAB_TYPE_BPE) {
+        return 99;
+    }
 
 #ifdef _WIN32
     // We need this for unicode console support
@@ -74,7 +78,7 @@ int main(int argc, char **argv) {
     const int n_vocab = llama_n_vocab(model);
 
     for (int i = 0; i < n_vocab; ++i) {
-        std::string str = llama_detokenize_bpe(ctx, std::vector<int>(1, i));
+        std::string str = llama_detokenize(ctx, std::vector<int>(1, i));
         try {
             auto cps = unicode_cpts_from_utf8(str);
             std::vector<llama_token> tokens = llama_tokenize(ctx, str, false, true);
@@ -90,7 +94,7 @@ int main(int argc, char **argv) {
                 fprintf(stderr, "]\n");
                 return 2;
             }
-            std::string check = llama_detokenize_bpe(ctx, tokens);
+            std::string check = llama_detokenize(ctx, tokens);
             if (check != str) {
                 fprintf(stderr, "%s : error: token %d detokenizes to '%s'(%zu) but tokenization of this detokenizes to '%s'(%zu)\n",
                     __func__, i, str.c_str(), str.length(), check.c_str(), check.length());
@@ -108,26 +112,23 @@ int main(int argc, char **argv) {
 
         std::vector<std::thread> threads(nthread);
 
+        std::atomic_int errcode = {};
+
         for (int i = 0; i < nthread; ++i) {
-            threads[i] = std::thread([i, nthread, ctx]() {
-                for (uint32_t cp = i; cp < 0x0010ffff; cp += nthread) {
-                    if (!( // NOLINT
-                                (cp < 0x03       || cp >  0x05)   && cp != 0x0b && cp != 0x11 &&
-                                (cp < 0x13       || cp >  0x17)   && cp != 0x19 &&
-                                (cp < 0x1c       || cp >  0x1e)   &&
-                                (cp < 0xd800     || cp >  0xdfff) &&
-                                (cp < 0x00040000 || cp >= 0x000e0000)
-                        )) {
+            threads[i] = std::thread([i, nthread, ctx, &errcode]() {
+                for (uint32_t cp = i; !errcode && cp < 0x00110000; cp += nthread) {
+                    if ((0x0000D800 <= cp && cp <= 0x0000DFFF) ||  // surrogates \p{Cs}
+                        (0x00040000 <= cp && cp <= 0x000E0000)) {  // undefined  \p{Cn}
                         continue;
                     }
 
                     std::string str = unicode_cpt_to_utf8(cp);
                     std::vector<llama_token> tokens = llama_tokenize(ctx, str, false);
-                    std::string check = llama_detokenize_bpe(ctx, tokens);
+                    std::string check = llama_detokenize(ctx, tokens);
                     if (cp != 9601 && str != check) {
-                        fprintf(stderr, "error: codepoint %x detokenizes to '%s'(%zu) instead of '%s'(%zu)\n",
+                        fprintf(stderr, "error: codepoint 0x%x detokenizes to '%s'(%zu) instead of '%s'(%zu)\n",
                                 cp, check.c_str(), check.length(), str.c_str(), str.length());
-                        std::exit(3);
+                        errcode = 3;
                     }
                 }
             });
@@ -136,6 +137,10 @@ int main(int argc, char **argv) {
         for (auto & t : threads) {
             t.join();
         }
+
+        if (errcode) {
+            return errcode;
+        }
     }
 
     llama_free_model(model);
index ac2333ddaf886f7ce334f443b9bd53a6a38849c1..7ca9e2ca6a671e7e5c01a8126fcdb6e23a271c91 100644 (file)
@@ -11,6 +11,7 @@
 #include <string>
 #include <thread>
 #include <vector>
+#include <atomic>
 
 int main(int argc, char ** argv) {
     if (argc < 2) {
@@ -51,7 +52,10 @@ int main(int argc, char ** argv) {
         }
     }
 
-    GGML_ASSERT(llama_vocab_type(model) == LLAMA_VOCAB_TYPE_SPM);
+    //GGML_ASSERT(llama_vocab_type(model) == LLAMA_VOCAB_TYPE_SPM);
+    if (llama_vocab_type(model) != LLAMA_VOCAB_TYPE_SPM) {
+        return 99;
+    }
 
 #ifdef _WIN32
     // We need this for unicode console support
@@ -62,9 +66,9 @@ int main(int argc, char ** argv) {
     const int n_vocab = llama_n_vocab(model);
 
     for (int i = 0; i < n_vocab; ++i) {
-        std::string str = llama_detokenize_spm(ctx, std::vector<int>(1, i));
-        std::vector<llama_token> tokens = llama_tokenize(ctx, str, false);
-        std::string check = llama_detokenize_spm(ctx, tokens);
+        std::string str = llama_detokenize(ctx, std::vector<int>(1, i), true);
+        std::vector<llama_token> tokens = llama_tokenize(ctx, str, false, true);
+        std::string check = llama_detokenize(ctx, tokens);
         if (check != str) {
             fprintf(stderr, "%s : error: token %d detokenizes to '%s'(%zu) but tokenization of this detokenizes to '%s'(%zu)\n",
                 __func__, i, str.c_str(), str.length(), check.c_str(), check.length());
@@ -78,20 +82,23 @@ int main(int argc, char ** argv) {
 
         std::vector<std::thread> threads(nthread);
 
+        std::atomic_int errcode = {};
+
         for (int i = 0; i < nthread; ++i) {
-            threads[i] = std::thread([i, nthread, ctx]() {
-                for (uint32_t cp = i; cp < 0x0010ffff; cp += nthread) {
-                    if (cp >= 0xd800 && cp <= 0xdfff) {
+            threads[i] = std::thread([i, nthread, ctx, &errcode]() {
+                for (uint32_t cp = i; !errcode && cp < 0x00110000; cp += nthread) {
+                    if ((0x0000D800 <= cp && cp <= 0x0000DFFF) ||  // surrogates \p{Cs}
+                        (0x00040000 <= cp && cp <= 0x000E0000)) {  // undefined \p{Cn}
                         continue;
                     }
 
                     std::string str = unicode_cpt_to_utf8(cp);
-                    std::vector<llama_token> tokens = llama_tokenize(ctx, str, false);
-                    std::string check = llama_detokenize_spm(ctx, tokens);
+                    std::vector<llama_token> tokens = llama_tokenize(ctx, str, false, true);
+                    std::string check = llama_detokenize(ctx, tokens);
                     if (cp != 9601 && str != check) {
-                        fprintf(stderr, "error: codepoint %x detokenizes to '%s'(%zu) instead of '%s'(%zu)\n",
+                        fprintf(stderr, "error: codepoint 0x%x detokenizes to '%s'(%zu) instead of '%s'(%zu)\n",
                                 cp, check.c_str(), check.length(), str.c_str(), str.length());
-                        std::exit(3);
+                        errcode = 3;
                     }
                 }
             });
@@ -100,6 +107,10 @@ int main(int argc, char ** argv) {
         for (auto & t : threads) {
             t.join();
         }
+
+        if(errcode) {
+            return errcode;
+        }
     }
 
     llama_free_model(model);
index a07c52fb3fc60d9c9a6e2c6b098c60fd455d1909..48cab8a1e0859079cc073f2635136db831c2d723 100644 (file)
@@ -13,7 +13,7 @@ import subprocess
 import random
 import unicodedata
 
-from typing import Callable, Iterator
+from typing import Iterator
 
 import cffi
 from transformers import AutoTokenizer
@@ -24,17 +24,20 @@ logger = logging.getLogger("test-tokenizer-random")
 
 class LibLlama:
 
-    DEFAULT_PATH_LLAMA_H = "./llama.h"
-    DEFAULT_PATH_LIBLLAMA = "./build/libllama.so"  # CMakeLists.txt: BUILD_SHARED_LIBS ON
+    DEFAULT_PATH_LLAMA_H = "./include/llama.h"
+    DEFAULT_PATH_INCLUDES = ["./ggml/include/", "./include/"]
+    DEFAULT_PATH_LIBLLAMA = "./build/src/libllama.so"  # CMakeLists.txt: BUILD_SHARED_LIBS ON
 
-    def __init__(self, path_llama_h: str = None, path_libllama: str = None):
+    def __init__(self, path_llama_h: str = None, path_includes: list[str] = [], path_libllama: str = None):
         path_llama_h = path_llama_h or self.DEFAULT_PATH_LLAMA_H
+        path_includes = path_includes or self.DEFAULT_PATH_INCLUDES
         path_libllama = path_libllama or self.DEFAULT_PATH_LIBLLAMA
-        (self.ffi, self.lib) = self._load_libllama_cffi(path_llama_h, path_libllama)
+        (self.ffi, self.lib) = self._load_libllama_cffi(path_llama_h, path_includes, path_libllama)
         self.lib.llama_backend_init()
 
-    def _load_libllama_cffi(self, path_llama_h: str, path_libllama: str):
-        cmd = ["gcc", "-E", "-P", "-D__restrict=", "-D__attribute__(x)=", "-D__asm__(x)=", path_llama_h]
+    def _load_libllama_cffi(self, path_llama_h: str, path_includes: list[str], path_libllama: str):
+        cmd = ["gcc", "-E", "-P", "-D__restrict=", "-D__attribute__(x)=", "-D__asm__(x)="]
+        cmd += ["-I" + path for path in path_includes] + [path_llama_h]
         res = subprocess.run(cmd, stdout=subprocess.PIPE)
         assert (res.returncode == 0)
         source = res.stdout.decode()
@@ -79,6 +82,7 @@ class LibLlamaModel:
             raise RuntimeError("error: failed to create context for model '%s'" % path_model)
         n_tokens_max = self.lib.llama_n_ctx(self.ctx)
         self.token_ids = self.ffi.new("llama_token[]", n_tokens_max)
+        self.text_buff = self.ffi.new("uint8_t[]", 1024)
 
     def free(self):
         if self.ctx:
@@ -89,14 +93,78 @@ class LibLlamaModel:
         self.model = None
         self.lib = None
 
-    def tokenize(self, text: str, n_tokens_max: int = 0, add_special: bool = False, parse_special: bool = False) -> list[int]:
-        n_tokens_max = n_tokens_max if n_tokens_max > 0 else len(self.token_ids)
+    def tokenize(self, text: str, add_special: bool = False, parse_special: bool = False) -> list[int]:
         text = text.encode("utf-8")
-        num = self.lib.llama_tokenize(self.model, text, len(text), self.token_ids, n_tokens_max, add_special, parse_special)
-        if num < 0:
-            return []
+        num = self.lib.llama_tokenize(self.model, text, len(text), self.token_ids, len(self.token_ids), add_special, parse_special)
+        while num < 0 and len(self.token_ids) < (16 << 20):
+            self.token_ids = self.ffi.new("llama_token[]", -2 * num)
+            num = self.lib.llama_tokenize(self.model, text, len(text), self.token_ids, len(self.token_ids), add_special, parse_special)
         return list(self.token_ids[0:num])
 
+    def detokenize(self, ids: list[int], remove_special: bool = False, unparse_special: bool = False) -> str:
+        if len(self.token_ids) < len(ids):
+            self.token_ids = self.ffi.new("llama_token[]", 2 * len(ids))
+        for i, id in enumerate(ids):
+            self.token_ids[i] = id
+        num = self.lib.llama_detokenize(self.model, self.token_ids, len(ids), self.text_buff, len(self.text_buff), remove_special, unparse_special)
+        while num < 0 and len(self.text_buff) < (16 << 20):
+            self.text_buff = self.ffi.new("uint8_t[]", -2 * num)
+            num = self.lib.llama_detokenize(self.model, self.token_ids, len(ids), self.text_buff, len(self.text_buff), remove_special, unparse_special)
+        return str(self.ffi.buffer(self.text_buff, num), encoding="utf-8", errors="replace")  # replace errors with '\uFFFD'
+
+
+class Tokenizer:
+
+    def encode(self, text: str) -> list[int]:
+        raise NotImplementedError
+
+    def decode(self, ids: list[int]) -> str:
+        raise NotImplementedError
+
+
+class TokenizerGroundtruth (Tokenizer):
+
+    def __init__(self, dir_tokenizer: str):
+        self.model = AutoTokenizer.from_pretrained(dir_tokenizer)
+        # guess BOS and EOS
+        ids = self.encode("a")
+        assert 1 <= len(ids) <= 3
+        add_bos_token = len(ids) > 1 and self.model.bos_token_id == ids[0]
+        add_eos_token = len(ids) > 1 and self.model.eos_token_id == ids[-1]
+        self.add_bos_token = getattr(self.model, "add_bos_token", add_bos_token)
+        self.add_eos_token = getattr(self.model, "add_eos_token", add_eos_token)
+        # build vocab
+        tokens = list(self.model.get_vocab().values())
+        self.vocab = self.model.batch_decode(tokens, skip_special_tokens=True)
+        self.vocab = list(sorted(self.vocab))
+        # tokens and lists
+        self.special_tokens = list(self.model.all_special_tokens)
+        self.added_tokens   = list(self.model.added_tokens_encoder)
+        self.bos_token = self.model.bos_token
+        self.eos_token = self.model.eos_token
+
+    def encode(self, text: str) -> list[int]:
+        return self.model.encode(text, add_special_tokens=True)
+
+    def decode(self, ids: list[int]) -> str:
+        return self.model.decode(ids, skip_special_tokens=False)
+
+
+class TokenizerLlamaCpp (Tokenizer):
+
+    libllama: LibLlama = None
+
+    def __init__(self, vocab_file: str):
+        if not self.libllama:
+            self.libllama = LibLlama()
+        self.model = LibLlamaModel(self.libllama, vocab_file, mparams=dict(vocab_only=True), cparams=dict(n_ctx=4096))
+
+    def encode(self, text: str) -> list[int]:
+        return self.model.tokenize(text, add_special=True, parse_special=True)
+
+    def decode(self, ids: list[int]) -> str:
+        return self.model.detokenize(ids, remove_special=False, unparse_special=True)
+
 
 def generator_custom_text() -> Iterator[str]:
     """General tests"""
@@ -165,19 +233,48 @@ def generator_custom_text_edge_cases() -> Iterator[str]:
         'a </s> b',        # rstrip phi-3
         'a <mask> b',      # lstrip jina-v2
         '\xa0aC',          # deepseek
+        '\u2029 \uA3E4',   # deepseek-llm
+        "a ?",
+        'å',               # mpt
+        '\U000ac517',      # utf-8 encode error, falcon
+        '\U000522f4',      # utf-8 encode error, starcoder
+        "<s><s><unk><s>a<s>b<s>c<unk>d<unk></s>",
+        "<s> <s> <unk><s>a<s>b<s>c<unk>d<unk></s>",
     ]
 
 
-def generator_vocab_words(vocab: list[str]) -> Iterator[str]:
+def generator_vocab_words(tokenizer: TokenizerGroundtruth) -> Iterator[str]:
     """Brute force check all vocab words"""
-    yield from vocab
-
-
-def generator_added_lr_strip(tokenizer) -> Iterator[str]:
-    WHITESPACES = ["", " ", "  ", "    "]
-    special_tokens = list(tokenizer.all_special_tokens)
-    added_tokens   = list(tokenizer.added_tokens_encoder)
-    all_tokens     = list(sorted(set(special_tokens + added_tokens)))
+    yield from tokenizer.vocab
+
+
+def generator_ascii_lr_strip() -> Iterator[str]:
+    WHITESPACES = ["", " ", "  "]
+    CHARACTERS = list(chr(i) for i in range(1, 0x80)) + [""]
+    for char1 in CHARACTERS:
+        for char2 in CHARACTERS:
+            for lstrip in WHITESPACES:
+                for rstrip in WHITESPACES:
+                    yield lstrip + char1 + char2 + rstrip
+                    yield lstrip + char1 + rstrip + char2
+                    yield char1 + lstrip + char2 + rstrip
+
+
+def generator_apostrophe() -> Iterator[str]:
+    WHITESPACES = ["", " ", "  "]
+    CHARACTERS = list(chr(i) for i in range(1, 0x80)) + [""]
+    for char1 in CHARACTERS:
+        for char2 in CHARACTERS:
+            for lstrip in WHITESPACES:
+                for rstrip in WHITESPACES:
+                    yield char1 + lstrip + "'" + rstrip + char2
+                    yield char1 + char2 + lstrip + "'" + rstrip + "z"
+                    yield "a" + lstrip + "'" + rstrip + char1 + char2
+
+
+def generator_added_lr_strip(tokenizer: TokenizerGroundtruth) -> Iterator[str]:
+    WHITESPACES = ["", " ", "  ", "\n", "\r\n", "\n\n", "\t", "\t\t"]
+    all_tokens = list(sorted(set(tokenizer.special_tokens + tokenizer.added_tokens)))
     for token in all_tokens:
         for lstrip in WHITESPACES:
             for rstrip in WHITESPACES:
@@ -187,11 +284,9 @@ def generator_added_lr_strip(tokenizer) -> Iterator[str]:
                 yield "a" + lstrip + token + rstrip + "z"
 
 
-def generator_random_added_tokens(tokenizer, iterations=100) -> Iterator[str]:
-    special_tokens = list(tokenizer.all_special_tokens)
-    added_tokens   = list(tokenizer.added_tokens_encoder)
-    separations    = [" ", "\n", "\t", "-", "!", "one", "1", "<s>", "</s>"]
-    all_tokens     = list(sorted(set(special_tokens + added_tokens + separations)))
+def generator_random_added_tokens(tokenizer: TokenizerGroundtruth, iterations=100) -> Iterator[str]:
+    separations = [" ", "\n", "\t", "-", "!", "one", "1", "<s>", "</s>"]
+    all_tokens  = list(sorted(set(tokenizer.special_tokens + tokenizer.added_tokens + separations)))
     rand = random.Random()
     for m in range(iterations):
         rand.seed(m)
@@ -242,13 +337,13 @@ def generator_unicodes() -> Iterator[str]:
     def _valid(cpt):
         if cpt >= 0x30000:  # unassigned and supplement­ary
             return False
-        if 0x00D800 <= cpt <= 0x00F8FF:  # Surrogates
-            return False
-        if unicodedata.category(chr(cpt)) == "Cn":
+        # if cpt == 0x2029:  # deepseek-llm
+        #    return False
+        if unicodedata.category(chr(cpt)) in ("Cn", "Cs", "Co"):  # undefined, surrogates, private
             return False
         return True
 
-    characters = [chr(cpt) for cpt in range(1, MAX_CODEPOINTS) if _valid(cpt)]
+    characters = [chr(cpt) for cpt in range(0, MAX_CODEPOINTS) if _valid(cpt)]
 
     yield from characters
 
@@ -273,11 +368,11 @@ def generator_random_unicodes(iterations=100) -> Iterator[str]:
         yield "".join(text)
 
 
-def generator_random_vocab_chars(vocab: list[str], iterations=100) -> Iterator[str]:
+def generator_random_vocab_chars(tokenizer: TokenizerGroundtruth, iterations=100) -> Iterator[str]:
     """Brute force random text with vocab characters"""
 
     vocab_chars = set()
-    for word in vocab:
+    for word in tokenizer.vocab:
         vocab_chars.update(word)
     vocab_chars = list(sorted(vocab_chars))
 
@@ -288,10 +383,10 @@ def generator_random_vocab_chars(vocab: list[str], iterations=100) -> Iterator[s
         yield "".join(text)
 
 
-def generator_random_vocab_words(vocab: list[str], iterations=100) -> Iterator[str]:
+def generator_random_vocab_words(tokenizer: TokenizerGroundtruth, iterations=100) -> Iterator[str]:
     """Brute force random text from vocab words"""
 
-    vocab = [w.strip() for w in vocab]
+    vocab = [w.strip() for w in tokenizer.vocab]
     yield from vocab
 
     rand = random.Random()
@@ -307,7 +402,7 @@ def generator_random_vocab_words(vocab: list[str], iterations=100) -> Iterator[s
         yield "".join(text)
 
 
-def compare_tokenizers(func_tokenize1: Callable, func_tokenize2: Callable, generator: Iterator[str]):
+def compare_tokenizers(tokenizer1: TokenizerGroundtruth, tokenizer2: TokenizerLlamaCpp, generator: Iterator[str]):
 
     def find_first_mismatch(ids1: list[int], ids2: list[int]):
         for i, (a, b) in enumerate(zip(ids1, ids2)):
@@ -317,34 +412,67 @@ def compare_tokenizers(func_tokenize1: Callable, func_tokenize2: Callable, gener
             return -1
         return min(len(ids1), len(ids2))
 
-    t_tokenizer1 = 0
-    t_tokenizer2 = 0
+    def check_detokenizer(text: str, text1: str, text2: str) -> bool:
+        if text1 == text2:  # equal to TokenizerGroundtruth?
+            return True
+        # equal to source text?
+        if tokenizer1.add_bos_token:  # remove BOS
+            if text2.startswith(tokenizer1.bos_token):
+                text2 = text2[len(tokenizer1.bos_token):]
+        if tokenizer1.add_eos_token:  # remove EOS
+            if text2.endswith(tokenizer1.eos_token):
+                text2 = text2[:-len(tokenizer1.eos_token)]
+        return text == text2
+
+    t_encode1 = 0
+    t_encode2 = 0
+    t_decode1 = 0
+    t_decode2 = 0
     t_start = time.perf_counter()
-    num_errors = 10
+    encode_errors = 0
+    decode_errors = 0
+    MAX_ERRORS = 10
 
     logger.info("%s: %s" % (generator.__name__, "ini"))
     for text in generator:
+        # print(repr(text), text.encode())
         # print(repr(text), hex(ord(text[0])), text.encode())
         t0 = time.perf_counter()
-        ids1 = func_tokenize1(text)
+        ids1 = tokenizer1.encode(text)
         t1 = time.perf_counter()
-        ids2 = func_tokenize2(text)
+        ids2 = tokenizer2.encode(text)
         t2 = time.perf_counter()
-        t_tokenizer1 += t1 - t0
-        t_tokenizer2 += t2 - t1
-        if ids1 != ids2:
+        text1 = tokenizer1.decode(ids1)
+        t3 = time.perf_counter()
+        text2 = tokenizer2.decode(ids1)
+        t4 = time.perf_counter()
+        t_encode1 += t1 - t0
+        t_encode2 += t2 - t1
+        t_decode1 += t3 - t2
+        t_decode2 += t4 - t3
+        if encode_errors < MAX_ERRORS and ids1 != ids2:
             i = find_first_mismatch(ids1, ids2)
             ids1 = list(ids1)[max(0, i - 2) : i + 5 + 1]
             ids2 = list(ids2)[max(0, i - 2) : i + 5 + 1]
-            logger.error(" TokenIDs: " + str(ids1))
-            logger.error(" Expected: " + str(ids2))
+            logger.error(" Expected: " + str(ids1))
+            logger.error("   Result: " + str(ids2))
+            encode_errors += 1
+            logger.error(f" {encode_errors=}")
+        if decode_errors < MAX_ERRORS and not check_detokenizer(text, text1, text2):
+            i = find_first_mismatch(text1, text2)
+            text1 = list(text1[max(0, i - 2) : i + 5 + 1])
+            text2 = list(text2[max(0, i - 2) : i + 5 + 1])
+            logger.error(" Expected: " + " ".join(hex(ord(x)) for x in text1))
+            logger.error("   Result: " + " ".join(hex(ord(x)) for x in text2))
+            decode_errors += 1
+            logger.error(f" {decode_errors=}")
+        if encode_errors >= MAX_ERRORS and decode_errors >= MAX_ERRORS:
+            logger.error(f" EXIT: {encode_errors=} {decode_errors=}")
             # raise Exception()
-            num_errors += 1
-            if num_errors > 10:
-                break
+            break
 
     t_total = time.perf_counter() - t_start
-    logger.info("%s: end,  tok1: %.3f  tok2: %.3f  total: %.3f" % (generator.__name__, t_tokenizer1, t_tokenizer2, t_total))
+    logger.info(f"{generator.__name__}: end,  {t_encode1=:.3f} {t_encode2=:.3f}  {t_decode1=:.3f} {t_decode2=:.3f}  {t_total=:.3f}")
 
 
 def main(argv: list[str] = None):
@@ -357,74 +485,76 @@ def main(argv: list[str] = None):
     logging.basicConfig(level = logging.DEBUG if args.verbose else logging.INFO)
     logger.info(f"VOCABFILE: '{args.vocab_file}'")
 
-    model = LibLlamaModel(LibLlama(), args.vocab_file, mparams=dict(vocab_only=True), cparams=dict(n_ctx=4096))
-    tokenizer = AutoTokenizer.from_pretrained(args.dir_tokenizer)
-
-    def func_tokenize1(text: str):
-        return model.tokenize(text, add_special=True, parse_special=True)
-
-    def func_tokenize2(text: str):
-        return tokenizer.encode(text, add_special_tokens=True)
+    tokenizer1 = TokenizerGroundtruth(args.dir_tokenizer)
+    tokenizer2 = TokenizerLlamaCpp(args.vocab_file)
 
-    ids = func_tokenize2("a")
-    assert 1 <= len(ids) <= 3
-    add_bos_token = len(ids) > 1 and tokenizer.bos_token_id == ids[0]
-    add_eos_token = len(ids) > 1 and tokenizer.eos_token_id == ids[-1]
-    tokenizer.add_bos_token = getattr(tokenizer, "add_bos_token", add_bos_token)
-    tokenizer.add_eos_token = getattr(tokenizer, "add_eos_token", add_eos_token)
+    # compare_tokenizers(tokenizer1, tokenizer2, generator_custom_text())
+    # compare_tokenizers(tokenizer1, tokenizer2, generator_custom_text_edge_cases())
+    compare_tokenizers(tokenizer1, tokenizer2, generator_ascii_lr_strip())
+    compare_tokenizers(tokenizer1, tokenizer2, generator_apostrophe())
+    compare_tokenizers(tokenizer1, tokenizer2, generator_unicodes())
+    compare_tokenizers(tokenizer1, tokenizer2, generator_vocab_words(tokenizer1))
+    compare_tokenizers(tokenizer1, tokenizer2, generator_added_lr_strip(tokenizer1))
+    # compare_tokenizers(tokenizer1, tokenizer2, generator_random_added_tokens(tokenizer1, 10_000))
+    # compare_tokenizers(tokenizer1, tokenizer2, generator_random_chars(10_000))
+    # compare_tokenizers(tokenizer1, tokenizer2, generator_random_unicodes(10_000))
+    # compare_tokenizers(tokenizer1, tokenizer2, generator_random_vocab_chars(tokenizer1, 10_000))
+    # compare_tokenizers(tokenizer1, tokenizer2, generator_random_vocab_words(tokenizer1, 5_000))
 
-    vocab = list(sorted(tokenizer.batch_decode(list(tokenizer.get_vocab().values()), skip_special_tokens=True)))
-
-    compare_tokenizers(func_tokenize1, func_tokenize2, generator_custom_text())
-    compare_tokenizers(func_tokenize1, func_tokenize2, generator_custom_text_edge_cases())
-    compare_tokenizers(func_tokenize1, func_tokenize2, generator_unicodes())
-    compare_tokenizers(func_tokenize1, func_tokenize2, generator_vocab_words(vocab))
-    compare_tokenizers(func_tokenize1, func_tokenize2, generator_added_lr_strip(tokenizer))
-    compare_tokenizers(func_tokenize1, func_tokenize2, generator_random_added_tokens(tokenizer, 10_000))
-    compare_tokenizers(func_tokenize1, func_tokenize2, generator_random_chars(10_000))
-    compare_tokenizers(func_tokenize1, func_tokenize2, generator_random_unicodes(10_000))
-    compare_tokenizers(func_tokenize1, func_tokenize2, generator_random_vocab_chars(vocab, 10_000))
-    compare_tokenizers(func_tokenize1, func_tokenize2, generator_random_vocab_words(vocab, 5_000))
-
-    model.free()
+    tokenizer2.model.free()
 
 
 if __name__ == "__main__":
     # main()
 
+    if True:
+        logging.basicConfig(
+            level    = logging.DEBUG,
+            format   = "%(asctime)s.%(msecs)03d %(name)s %(levelname)s %(message)s",
+            datefmt  = "%Y-%m-%d %H:%M:%S",
+            filename = logger.name + ".log",
+            filemode = "a"
+        )
     logging.basicConfig(
         level    = logging.DEBUG,
-        format   = "%(asctime)s.%(msecs)03d %(name)s %(levelname)s %(message)s",
-        datefmt  = "%Y-%m-%d %H:%M:%S",
-        filename = logger.name + ".log",
-        filemode = "a"
+        format   = "%(levelname)s %(message)s",
     )
 
     path_tokenizers   = "./models/tokenizers/"
     path_vocab_format = "./models/ggml-vocab-%s.gguf"
 
-    # import os
-    # tokenizers = os.listdir(path_tokenizers)
     tokenizers = [
-        # "llama-spm",   # SPM
-        # "phi-3",       # SPM
-        # "bert-bge",    # WPM
-        # "jina-v2-en",  # WPM
-        "gpt-2",          # BPE
+        "llama-spm",      # SPM
+        "phi-3",          # SPM
+        "gemma",          # SPM
+        "gemma-2",        # SPM
+        "baichuan",       # SPM
+        "bert-bge",       # WPM
+        "jina-v2-en",     # WPM
         "llama-bpe",      # BPE
+        "phi-2",          # BPE
+        "deepseek-llm",   # BPE
+        "deepseek-coder", # BPE
         "falcon",         # BPE
+        "mpt",            # BPE
         "starcoder",      # BPE
+        "gpt-2",          # BPE
+        "stablelm2",      # BPE
+        "refact",         # BPE
+        "qwen2",          # BPE
+        "olmo",           # BPE
         "jina-v2-es",     # BPE
         "jina-v2-de",     # BPE
-        "jina-v2-code",   # BPE
         "smaug-bpe",      # BPE
-        "phi-2",          # BPE
-        "deepseek-coder", # BPE
-        "deepseek-llm",   # BPE
+        "poro-chat",      # BPE
+        "jina-v2-code",   # BPE
+        "viking",         # BPE
+        "jais",           # BPE
     ]
 
+    logger.info("=" * 50)
     for tokenizer in tokenizers:
-        logger.info("=" * 50)
+        logger.info("-" * 50)
         logger.info(f"TOKENIZER: '{tokenizer}'")
         vocab_file = path_vocab_format % tokenizer
         dir_tokenizer = path_tokenizers + "/" + tokenizer