]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
Minor improvements in GPT2 tokenizer (#3567)
authorgoerch <redacted>
Tue, 10 Oct 2023 16:59:52 +0000 (18:59 +0200)
committerGitHub <redacted>
Tue, 10 Oct 2023 16:59:52 +0000 (18:59 +0200)
* Fixing minor bugs in bpe_gpt2_preprocess

* Don't add bos token in test

llama.cpp
tests/test-tokenizer-0-falcon.cpp
tests/test-tokenizer-0-falcon.py
tests/test-tokenizer-0-llama.cpp
tests/test-tokenizer-0-llama.py

index 4653c80232c5cd1bedcb4ae3816c5aacfcefe251..7ed8722376d2df6397a4f339915b1fd115669599 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
@@ -6342,7 +6342,6 @@ private:
         for (int i = 0; i < (int)text_utf.size(); i++) {
             const std::string & utf_char = text_utf[i];
             bool split_condition = false;
-            // const char* text_pos = raw_text_p + utf_char.seq_offset_bytes;
             int bytes_remain = text_utf.size() - i;
             // forward backward lookups
             const std::string & utf_char_next = (i + 1 < (int)text_utf.size()) ? text_utf[i + 1] : "";
@@ -6368,9 +6367,9 @@ private:
             if (!split_condition && bytes_remain >= 3) {
                 // 're|'ve|'ll
                 if (utf_char == "\'" && (
-                    (utf_char_next == "r" || utf_char_next_next == "e") ||
-                    (utf_char_next == "v" || utf_char_next_next == "e") ||
-                    (utf_char_next == "l" || utf_char_next_next == "l"))
+                    (utf_char_next == "r" && utf_char_next_next == "e") ||
+                    (utf_char_next == "v" && utf_char_next_next == "e") ||
+                    (utf_char_next == "l" && utf_char_next_next == "l"))
                     ) {
                     split_condition = true;
                 }
@@ -6421,7 +6420,7 @@ private:
                 else if (collecting_special && (codepoint_type(utf_char) == CODEPOINT_TYPE_LETTER || codepoint_type(utf_char) == CODEPOINT_TYPE_DIGIT || codepoint_type(utf_char) == CODEPOINT_TYPE_WHITESPACE)) {
                     split_condition = true;
                 }
-                else if (collecting_whitespace_lookahead && codepoint_type(utf_char_next) != CODEPOINT_TYPE_WHITESPACE) {
+                else if (collecting_whitespace_lookahead && (codepoint_type(utf_char_next) == CODEPOINT_TYPE_LETTER || codepoint_type(utf_char_next) == CODEPOINT_TYPE_DIGIT)) {
                     split_condition = true;
                 }
             }
index 0f3c50bce8ae9d82259a7f421a5364c59ec70655..a4e9d2b91272876feb2fdcc2b67ee3c50d884b86 100644 (file)
@@ -36,6 +36,8 @@ static const std::map<std::string, std::vector<llama_token>> & k_tests() {
         { "   Hello"              , {     258,  23090, }, },
         { "    Hello"             , {     466,  23090, }, },
         { "    Hello\n    Hello"  , {     466,  23090,    742,  23090, }, },
+        { "\n ="                  , {    1212,     40, }, },
+        { "' era"                 , {      18,   4932, }, },
     };
 
     return _k_tests;
@@ -155,7 +157,7 @@ int main(int argc, char **argv) {
 
         fprintf(stderr, "%s : text size: %zu\n", __func__, text.size());
 
-        const std::vector<llama_token> res = llama_tokenize(ctx, text, true);
+        const std::vector<llama_token> res = llama_tokenize(ctx, text, false);
 
         fprintf(stderr, "%s : tokens: %zu\n", __func__, res.size());
 
@@ -169,10 +171,8 @@ int main(int argc, char **argv) {
             }
 
             for (const auto & tok : res) {
-                ofs << tok << " ";
+                ofs << tok << " '" << llama_detokenize_bpe(ctx, std::vector<int>{tok}) << "'" << std::endl;
             }
-
-            ofs << "\n";
         }
 
         fprintf(stderr, "%s : tokens written to '%s'\n", __func__, (fname_text + ".tokcpp").c_str());
index 9c8c1c7d1d3ca476192d42eb49ff8c10356f2664..cf65a3f65d72cc406f85b6187b1406d3a3871b40 100644 (file)
@@ -41,6 +41,8 @@ tests = [
         "   Hello",
         "    Hello",
         "    Hello\n    Hello",
+        "\n =",
+        "' era",
     ]
 
 for text in tests:
@@ -69,15 +71,14 @@ fname_tok = args.fname_tok
 if fname_tok:
     print('tokenizing file: ', fname_tok)
     fname_out = fname_tok + '.tok'
-    with open(fname_tok, 'r') as f:
+    with open(fname_tok, 'r', encoding='utf-8') as f:
         lines = f.readlines()
         s = ''.join(lines)
         res = tokenizer.encode(s)
         # write to file
-        with open(fname_out, 'w') as f:
+        with open(fname_out, 'w', encoding='utf-8') as f:
             for x in res:
-                f.write(str(x) + ' ')
-            f.write('\n')
+                f.write(str(x) + ' \'' + tokenizer.decode(x) + '\'\n')
         print('len(res): ', len(res))
         print('len(lines): ', len(lines))
     print('results written to: ', fname_out)
index 91c841f7bba8f690e2b75bf1ecd272526376bd6c..39c8d188c908614d635992a80b210790af6fda71 100644 (file)
@@ -174,10 +174,8 @@ int main(int argc, char **argv) {
             }
 
             for (const auto & tok : res) {
-                ofs << tok << " ";
+                ofs << tok << " '" << llama_detokenize_spm(ctx, std::vector<int>{tok}) << "'" << std::endl;
             }
-
-            ofs << "\n";
         }
 
         fprintf(stderr, "%s : tokens written to '%s'\n", __func__, (fname_text + ".tokcpp").c_str());
index bc164ee296cb1d6ff1ee28469ee913b7c345e12e..078f680b165ca1a76045985e29ea267f47ee8f84 100644 (file)
@@ -81,15 +81,14 @@ fname_tok = args.fname_tok
 if fname_tok:
     print('tokenizing file: ', fname_tok)
     fname_out = fname_tok + '.tok'
-    with open(fname_tok, 'r') as f:
+    with open(fname_tok, 'r', encoding='utf-8') as f:
         lines = f.readlines()
         s = ''.join(lines)
         res = tokenizer.encode(s, add_bos=True)
         # write to file
-        with open(fname_out, 'w') as f:
+        with open(fname_out, 'w', encoding='utf-8') as f:
             for x in res:
-                f.write(str(x) + ' ')
-            f.write('\n')
+                f.write(str(x) + ' \'' + tokenizer.decode(x) + '\'\n')
         print('len(res): ', len(res))
         print('len(lines): ', len(lines))
     print('results written to: ', fname_out)