]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
Per token attributes (#7685)
authorjaime-m-p <redacted>
Tue, 4 Jun 2024 07:17:17 +0000 (09:17 +0200)
committerGitHub <redacted>
Tue, 4 Jun 2024 07:17:17 +0000 (09:17 +0200)
* Add per token attributes enum
* Using phi-3 for testing 'rstrip'
* Using jina-v2 for testing 'lstrip'
* Brute force test for 'lstrip' and 'rstrip'
* Implement 'rstrip' and 'lstrip'
* Update phi-3 GGUF file (obsolete since 917dc8c)
* Replace llama_token_type with llama_token_attribs

llama.cpp
llama.h
models/ggml-vocab-phi-3.gguf
tests/test-tokenizer-random.py

index b0bf70e20cfc952ad5852d01d422540759278134..a3e944874cf807fb0c4a521f7cbdce89e0cc50ae 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
@@ -2149,12 +2149,12 @@ struct llama_control_vector {
 struct llama_vocab {
     using id    = int32_t;
     using token = std::string;
-    using ttype = llama_token_type;
+    using tattr = llama_token_attr;
 
     struct token_data {
         token text;
         float score;
-        ttype type;
+        tattr attr;
     };
 
     enum llama_vocab_type     type     = LLAMA_VOCAB_TYPE_SPM;
@@ -4750,7 +4750,20 @@ static void llm_load_vocab(
         auto & token_data = vocab.id_to_token[i];
         token_data.text  = std::move(word);
         token_data.score = scores ? scores[i] : 0.0f;
-        token_data.type  = toktypes ? (llama_token_type) toktypes[i] : LLAMA_TOKEN_TYPE_NORMAL;
+        token_data.attr  = LLAMA_TOKEN_ATTR_NORMAL;
+
+        if (toktypes) {  //TODO: remove, required until per token attributes are available from GGUF file
+            switch(toktypes[i]) {
+                case LLAMA_TOKEN_TYPE_UNKNOWN:      token_data.attr = LLAMA_TOKEN_ATTR_UNKNOWN;      break;
+                case LLAMA_TOKEN_TYPE_UNUSED:       token_data.attr = LLAMA_TOKEN_ATTR_UNUSED;       break;
+                case LLAMA_TOKEN_TYPE_NORMAL:       token_data.attr = LLAMA_TOKEN_ATTR_NORMAL;       break;
+                case LLAMA_TOKEN_TYPE_CONTROL:      token_data.attr = LLAMA_TOKEN_ATTR_CONTROL;      break;
+                case LLAMA_TOKEN_TYPE_USER_DEFINED: token_data.attr = LLAMA_TOKEN_ATTR_USER_DEFINED; break;
+                case LLAMA_TOKEN_TYPE_BYTE:         token_data.attr = LLAMA_TOKEN_ATTR_BYTE;         break;
+                case LLAMA_TOKEN_TYPE_UNDEFINED:    token_data.attr = LLAMA_TOKEN_ATTR_UNDEFINED;    break;
+                default:                            token_data.attr = LLAMA_TOKEN_ATTR_UNDEFINED;    break;
+            }
+        }
     }
     GGML_ASSERT(vocab.id_to_token.size() == vocab.token_to_id.size());
 
@@ -4841,7 +4854,7 @@ static void llm_load_vocab(
     // build special tokens cache
     {
         for (llama_vocab::id id = 0; id < (llama_vocab::id)n_vocab; ++id) {
-            if (vocab.id_to_token[id].type != LLAMA_TOKEN_TYPE_NORMAL) {
+            if (!(vocab.id_to_token[id].attr & LLAMA_TOKEN_ATTR_NORMAL)) {
                 vocab.cache_special_tokens.push_back(id);
             }
         }
@@ -4871,6 +4884,59 @@ static void llm_load_vocab(
 
         LLAMA_LOG_INFO("%s: token to piece cache size = %.4f MB\n", __func__, size_cache / 1024.0 / 1024.0);
     }
+
+    // Handle per token attributes
+    //NOTE: Each model customizes per token attributes.
+    //NOTE: Per token attributes are missing from the GGUF file.
+    //TODO: Extract attributes from GGUF file.
+    {
+        auto _contains_any = [] (const std::string &str, const std::vector<std::string> &substrs) -> bool {
+            for (auto substr : substrs) {
+                if (str.find(substr) < std::string::npos) {
+                    return true;
+                }
+            }
+            return false;
+        };
+
+        auto _set_tokenid_attr = [&] (const llama_vocab::id id, llama_token_attr attr, bool value) {
+            uint32_t current = vocab.id_to_token.at(id).attr;
+            current = value ? (current | attr) : (current & ~attr);
+            vocab.id_to_token[id].attr = (llama_token_attr) current;
+        };
+
+        auto _set_token_attr = [&] (const std::string & token, llama_token_attr attr, bool value) {
+            _set_tokenid_attr(vocab.token_to_id.at(token), attr, value);
+        };
+
+        std::string model_name;
+        std::string tokenizer_pre;
+
+        ml.get_key(LLM_KV_GENERAL_NAME, model_name, false);
+        ml.get_key(LLM_KV_TOKENIZER_PRE, tokenizer_pre, false);
+
+        // model name to lowercase
+        std::transform(model_name.begin(), model_name.end(), model_name.begin(),
+            [] (const std::string::value_type x) {
+                return std::tolower(x);
+            }
+        );
+
+        // set attributes by model/tokenizer name
+        if (_contains_any(tokenizer_pre, {"jina-v2-es", "jina-v2-de"})) {
+            _set_token_attr("<mask>", LLAMA_TOKEN_ATTR_LSTRIP, true);
+        } else if (_contains_any(model_name, {"phi-3", "phi3"})) {
+            for (auto id : vocab.cache_special_tokens) {
+                _set_tokenid_attr(id, LLAMA_TOKEN_ATTR_RSTRIP, true);
+            }
+            for (auto token : {"</s>"}) {
+                _set_token_attr(token, LLAMA_TOKEN_ATTR_RSTRIP, true);
+            }
+            for (auto token : {"<unk>", "<s>", "<|endoftext|>"}) {
+                _set_token_attr(token, LLAMA_TOKEN_ATTR_RSTRIP, false);
+            }
+        }
+    }
 }
 
 static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
@@ -12620,27 +12686,27 @@ static enum llama_vocab_type llama_vocab_get_type(const llama_vocab & vocab) {
 
 static bool llama_is_normal_token(const llama_vocab & vocab, llama_token id) {
     GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE);
-    return vocab.id_to_token[id].type == LLAMA_TOKEN_TYPE_NORMAL;
+    return vocab.id_to_token[id].attr & LLAMA_TOKEN_ATTR_NORMAL;
 }
 
 static bool llama_is_unknown_token(const llama_vocab & vocab, llama_token id) {
     GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE);
-    return vocab.id_to_token[id].type == LLAMA_TOKEN_TYPE_UNKNOWN;
+    return vocab.id_to_token[id].attr & LLAMA_TOKEN_ATTR_UNKNOWN;
 }
 
 static bool llama_is_control_token(const llama_vocab & vocab, llama_token id) {
     GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE);
-    return vocab.id_to_token[id].type == LLAMA_TOKEN_TYPE_CONTROL;
+    return vocab.id_to_token[id].attr & LLAMA_TOKEN_ATTR_CONTROL;
 }
 
 static bool llama_is_byte_token(const llama_vocab & vocab, llama_token id) {
     GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE);
-    return vocab.id_to_token[id].type == LLAMA_TOKEN_TYPE_BYTE;
+    return vocab.id_to_token[id].attr & LLAMA_TOKEN_ATTR_BYTE;
 }
 
 static bool llama_is_user_defined_token(const llama_vocab& vocab, llama_token id) {
     GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE);
-    return vocab.id_to_token[id].type == LLAMA_TOKEN_TYPE_USER_DEFINED;
+    return vocab.id_to_token[id].attr & LLAMA_TOKEN_ATTR_USER_DEFINED;
 }
 
 static uint8_t llama_token_to_byte(const llama_vocab& vocab, llama_token id) {
@@ -13258,7 +13324,8 @@ struct fragment_buffer_variant {
 static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list<fragment_buffer_variant> & buffer) {
     // for each special token
     for (const llama_vocab::id special_id : vocab.cache_special_tokens) {
-        const auto & special_token = vocab.id_to_token[special_id].text;
+        const auto & data = vocab.id_to_token[special_id];
+        const auto & special_token = data.text;
 
         // for each text fragment
         std::forward_list<fragment_buffer_variant>::iterator it = buffer.begin();
@@ -13295,13 +13362,22 @@ static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list<
                     if (match > raw_text_base_offset) {
                         // left
                         const int64_t left_reminder_offset = raw_text_base_offset + 0;
-                        const int64_t left_reminder_length = match - raw_text_base_offset;
-                        buffer.emplace_after(it, raw_text, left_reminder_offset, left_reminder_length);
+                        int64_t left_reminder_length = match - raw_text_base_offset;
+
+                        if (data.attr & LLAMA_TOKEN_ATTR_LSTRIP) {
+                            while (left_reminder_length > 0 && isspace(raw_text[left_reminder_offset + left_reminder_length - 1])) {
+                                left_reminder_length--;
+                            }
+                        }
+
+                        if (left_reminder_length > 0) {
+                            buffer.emplace_after(it, raw_text, left_reminder_offset, left_reminder_length);
+                            it++;
+                        }
 
 #ifdef PRETOKENIZERDEBUG
                         LLAMA_LOG_WARN("FL: (%ld %ld) '%s'\n", left_reminder_offset, left_reminder_length, raw_text->substr(left_reminder_offset, left_reminder_length).c_str());
 #endif
-                        it++;
                     }
 
                     // special token
@@ -13310,16 +13386,25 @@ static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list<
 
                     // right
                     if (match + special_token.length() < raw_text_base_offset + raw_text_base_length) {
-                        const int64_t right_reminder_offset = match + special_token.length();
-                        const int64_t right_reminder_length = raw_text_base_length - ((match - raw_text_base_offset) + special_token.length());
-                        buffer.emplace_after(it, raw_text, right_reminder_offset, right_reminder_length);
+                        int64_t right_reminder_offset = match + special_token.length();
+                        int64_t right_reminder_length = raw_text_base_length - ((match - raw_text_base_offset) + special_token.length());
+
+                        if (data.attr & LLAMA_TOKEN_ATTR_RSTRIP) {
+                            while (right_reminder_length > 0 && isspace(raw_text[right_reminder_offset])) {
+                                right_reminder_offset++;
+                                right_reminder_length--;
+                            }
+                        }
+
+                        if (right_reminder_length > 0) {
+                            buffer.emplace_after(it, raw_text, right_reminder_offset, right_reminder_length);
+                            it++;
+                        }
 
 #ifdef PRETOKENIZERDEBUG
                         LLAMA_LOG_WARN("FR: (%ld %ld) '%s'\n", right_reminder_offset, right_reminder_length, raw_text->substr(right_reminder_offset, right_reminder_length).c_str());
 #endif
 
-                        it++;
-
                         if (source == 0) {
                             buffer.erase_after(buffer.before_begin());
                         } else {
@@ -13365,9 +13450,7 @@ static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab &
                 // tokenizer.encode('', add_special_tokens=True)  returns [1]
                 // tokenizer.encode('', add_special_tokens=False) returns []
 
-                static const bool rtrim = true;  //TODO: as param
                 bool is_prev_special = false;
-                bool special_token_rtrim = false;
 
                 if (add_special && vocab.special_add_bos != 0) {
                     GGML_ASSERT(vocab.special_bos_id != -1);
@@ -13377,25 +13460,8 @@ static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab &
 
                 for (const auto & fragment : fragment_buffer) {
                     if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
-                        // without adding this leading whitespace, we do not get the same results as the original tokenizer
-
-                        // TODO: It's likely possible to get rid of this string copy entirely
-                        //  by modifying llm_tokenizer_x to operate with string offsets like pre-tokenizer
-                        //  and passing 'add space prefix' as bool argument
-                        //
                         auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length);
 
-                        if (special_token_rtrim) {
-                            size_t num_whitespaces = 0;
-                            while (isspace(raw_text[num_whitespaces])) {
-                                num_whitespaces++;
-                            }
-                            if (num_whitespaces == raw_text.size()) {
-                                continue; // skip if all whitespaces
-                            }
-                            raw_text = raw_text.substr(num_whitespaces);
-                        }
-
                         if (vocab.add_space_prefix) {
                             if (!output.size() || is_prev_special) {  // prefix with space if first token
                                 raw_text = " " + raw_text;
@@ -13411,11 +13477,6 @@ static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab &
                     } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
                         output.push_back(fragment.token);
                         is_prev_special = true;
-                        // phi-3 special tokens without rtrim, works fine for llama-spm too
-                        special_token_rtrim = rtrim
-                            && fragment.token != vocab.special_bos_id
-                            && fragment.token != vocab.special_unk_id
-                            && fragment.token != vocab.special_eos_id;
                     }
                 }
 
@@ -18221,9 +18282,9 @@ float llama_token_get_score(const struct llama_model * model, llama_token token)
     return model->vocab.id_to_token[token].score;
 }
 
-llama_token_type llama_token_get_type(const struct llama_model * model, llama_token token) {
+llama_token_attr llama_token_get_attr(const struct llama_model * model, llama_token token) {
     GGML_ASSERT(model->vocab.type != LLAMA_VOCAB_TYPE_NONE);
-    return model->vocab.id_to_token[token].type;
+    return model->vocab.id_to_token[token].attr;
 }
 
 bool llama_token_is_eog(const struct llama_model * model, llama_token token) {
diff --git a/llama.h b/llama.h
index 95105c28e5e42b038352695d51b51c80d4935350..a78ccdaf557d05c3e222ee53e5c4737eda644eee 100644 (file)
--- a/llama.h
+++ b/llama.h
@@ -97,7 +97,7 @@ extern "C" {
         LLAMA_ROPE_TYPE_GLM  =  4,
     };
 
-    enum llama_token_type {
+    enum llama_token_type { //TODO: remove, required until per token attributes are available from GGUF file
         LLAMA_TOKEN_TYPE_UNDEFINED    = 0,
         LLAMA_TOKEN_TYPE_NORMAL       = 1,
         LLAMA_TOKEN_TYPE_UNKNOWN      = 2,
@@ -107,6 +107,20 @@ extern "C" {
         LLAMA_TOKEN_TYPE_BYTE         = 6,
     };
 
+    enum llama_token_attr {
+        LLAMA_TOKEN_ATTR_UNDEFINED    = 0,
+        LLAMA_TOKEN_ATTR_UNKNOWN      = 1 <<  1,
+        LLAMA_TOKEN_ATTR_UNUSED       = 1 <<  2,
+        LLAMA_TOKEN_ATTR_NORMAL       = 1 <<  3,
+        LLAMA_TOKEN_ATTR_CONTROL      = 1 <<  4,  // SPECIAL?
+        LLAMA_TOKEN_ATTR_USER_DEFINED = 1 <<  5,
+        LLAMA_TOKEN_ATTR_BYTE         = 1 <<  6,
+        LLAMA_TOKEN_ATTR_NORMALIZED   = 1 <<  7,
+        LLAMA_TOKEN_ATTR_LSTRIP       = 1 <<  8,
+        LLAMA_TOKEN_ATTR_RSTRIP       = 1 <<  9,
+        LLAMA_TOKEN_ATTR_SINGLE_WORD  = 1 << 10,
+    };
+
     // model file types
     enum llama_ftype {
         LLAMA_FTYPE_ALL_F32              = 0,
@@ -821,7 +835,7 @@ extern "C" {
 
     LLAMA_API float llama_token_get_score(const struct llama_model * model, llama_token token);
 
-    LLAMA_API enum llama_token_type llama_token_get_type(const struct llama_model * model, llama_token token);
+    LLAMA_API enum llama_token_attr llama_token_get_attr(const struct llama_model * model, llama_token token);
 
     // Check if the token is supposed to end generation (end-of-generation, eg. EOS, EOT, etc.)
     LLAMA_API bool llama_token_is_eog(const struct llama_model * model, llama_token token);
index f8022a385e4aa48ca10d40d0f079a25365a7be78..745be416a798a1e7a2effaa935bc903edaaf303d 100644 (file)
Binary files a/models/ggml-vocab-phi-3.gguf and b/models/ggml-vocab-phi-3.gguf differ
index ec1b2837cfab5c697e53cd6ec10628d00469f6a8..52f589511e47032bbcf52b8e645afc593a56ae61 100644 (file)
@@ -156,17 +156,39 @@ def generator_custom_text_edge_cases() -> Iterator[str]:
         '<s>a',       # Phi-3 fail
         '<unk><|endoftext|><s>',  # Phi-3 fail
         'a\na',       # TODO: Bert fail
+        'a </s> b',   # rstrip phi-3
+        'a <mask> b', # lstrip jina-v2
     ]
 
 
-def generator_random_special_tokens(tokenizer, iterations=100) -> Iterator[str]:
-    special_tokens = set(tokenizer.all_special_tokens)
-    special_tokens.update([" ", "\n", "\t", "-", "!", "one", "1", "<s>", "</s>"])
-    special_tokens = list(sorted(special_tokens))
+def generator_vocab_words(vocab: list[str]) -> Iterator[str]:
+    """Brute force check all vocab words"""
+    yield from vocab
+
+
+def generator_added_lr_strip(tokenizer) -> Iterator[str]:
+    WHITESPACES = ["", " ", "  ", "    "]
+    special_tokens = list(tokenizer.all_special_tokens)
+    added_tokens   = list(tokenizer.added_tokens_encoder)
+    all_tokens     = list(sorted(set(special_tokens + added_tokens)))
+    for token in all_tokens:
+        for lstrip in WHITESPACES:
+            for rstrip in WHITESPACES:
+                yield lstrip + token + rstrip
+                yield "a" + lstrip + token + rstrip
+                yield lstrip + token + rstrip + "z"
+                yield "a" + lstrip + token + rstrip + "z"
+
+
+def generator_random_added_tokens(tokenizer, iterations=100) -> Iterator[str]:
+    special_tokens = list(tokenizer.all_special_tokens)
+    added_tokens   = list(tokenizer.added_tokens_encoder)
+    separations    = [" ", "\n", "\t", "-", "!", "one", "1", "<s>", "</s>"]
+    all_tokens     = list(sorted(set(special_tokens + added_tokens + separations)))
     rand = random.Random()
     for m in range(iterations):
         rand.seed(m)
-        words = rand.choices(special_tokens, k=500)
+        words = rand.choices(all_tokens, k=500)
         if words[0] == tokenizer.bos_token:  # skip spam warning of double BOS
             while len(words) > 1 and words[1] == tokenizer.bos_token:  # leave one starting BOS
                 words.pop(0)
@@ -175,11 +197,6 @@ def generator_random_special_tokens(tokenizer, iterations=100) -> Iterator[str]:
         yield "".join(words)
 
 
-def generator_vocab_words(vocab: list[str]) -> Iterator[str]:
-    """Brute force check all vocab words"""
-    yield from vocab
-
-
 def generator_random_chars(iterations=100) -> Iterator[str]:
     """Brute force random text with simple characters"""
 
@@ -274,8 +291,8 @@ def test_compare_tokenizer(func_tokenize1: Callable, func_tokenize2: Callable, g
         ids2 = func_tokenize2(text)
         if ids1 != ids2:
             i = find_first_mismatch(ids1, ids2)
-            ids1 = list(ids1)[max(0, i - 2) : i + 2 + 1]
-            ids2 = list(ids2)[max(0, i - 2) : i + 2 + 1]
+            ids1 = list(ids1)[max(0, i - 2) : i + 5 + 1]
+            ids2 = list(ids2)[max(0, i - 2) : i + 5 + 1]
             logger.info(" TokenIDs: " + str(ids1))
             logger.info(" Expected: " + str(ids2))
             raise Exception()
@@ -309,8 +326,9 @@ def main(argv: list[str] = None):
     vocab = list(sorted(tokenizer.batch_decode(list(tokenizer.get_vocab().values()), skip_special_tokens=True)))
     test_compare_tokenizer(func_tokenize1, func_tokenize2, generator_custom_text())
     test_compare_tokenizer(func_tokenize1, func_tokenize2, generator_custom_text_edge_cases())
-    test_compare_tokenizer(func_tokenize1, func_tokenize2, generator_random_special_tokens(tokenizer, 10_000))
     test_compare_tokenizer(func_tokenize1, func_tokenize2, generator_vocab_words(vocab))
+    test_compare_tokenizer(func_tokenize1, func_tokenize2, generator_added_lr_strip(tokenizer))
+    test_compare_tokenizer(func_tokenize1, func_tokenize2, generator_random_added_tokens(tokenizer, 10_000))
     test_compare_tokenizer(func_tokenize1, func_tokenize2, generator_random_chars(10_000))
     test_compare_tokenizer(func_tokenize1, func_tokenize2, generator_random_vocab_chars(vocab, 10_000))
     test_compare_tokenizer(func_tokenize1, func_tokenize2, generator_random_vocab_words(vocab, 5_000))
@@ -322,14 +340,14 @@ def main(argv: list[str] = None):
 if __name__ == "__main__":
     # main()
 
-    path_tokenizers = "./models/tokenizers/"
+    path_tokenizers   = "./models/tokenizers/"
     path_vocab_format = "./models/ggml-vocab-%s.gguf"
 
     # import os
     # tokenizers = os.listdir(path_tokenizers)
     tokenizers = [
-        "llama-spm",   # SPM
-        "phi-3",       # SPM
+        "llama-spm",   # SPM
+        "phi-3",       # SPM
         "jina-v2-en",  # WPM
         "bert-bge",    # WPM
     ]