]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
llama : fix whitespace escaping in tokenizer (#2724)
authorgoerch <redacted>
Tue, 22 Aug 2023 21:10:42 +0000 (23:10 +0200)
committerGitHub <redacted>
Tue, 22 Aug 2023 21:10:42 +0000 (00:10 +0300)
llama.cpp
tests/test-tokenizer-0.cpp
tests/test-tokenizer-1.cpp

index 6abdc44f2a0625da2593dea04729964579e3853a..6c5da130926fcbbb8d2a97418d84d55f99ca78ca 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
@@ -2253,18 +2253,11 @@ static llama_token llama_byte_to_token(const llama_vocab & vocab, uint8_t ch) {
 }
 
 static std::string llama_escape_whitespace(const std::string& text) {
-    std::string result;
-    bool escaping = false;
-    result += "\xe2\x96\x81";
+    std::string result = "\xe2\x96\x81";
     for (size_t offs = 0; offs < text.length(); ++offs) {
         if (text[offs] == ' ') {
-            if (!escaping) {
-                result += "\xe2\x96\x81";
-                escaping = true;
-            }
-        }
-        else {
-            escaping = false;
+            result += "\xe2\x96\x81";
+        } else {
             result += text[offs];
         }
     }
index 81764565b5710bdbddb85a5f6680f9b4aca263e8..f3ee851a3880ca029ef41458ea868cbc4a86776d 100644 (file)
@@ -17,6 +17,8 @@ static std::string unescape_whitespace(llama_context* ctx, const std::vector<lla
 static const std::map<std::string, std::vector<llama_token>> & k_tests() {
     static std::map<std::string, std::vector<llama_token>> _k_tests = {
         { " ",                      {1,    259, }, },
+        { "  ",                     { 1,    1678, }, },
+        { "   ",                    { 1,     268, }, },
         { "\t",                     { 1,    29871,   12, }, },
         { "\n",                     { 1,    29871,   13, }, },
         { "\t\n",                   { 1,    29871,   12,     13, }, },
@@ -38,6 +40,12 @@ static const std::map<std::string, std::vector<llama_token>> & k_tests() {
                 243,    162,    155,    185,  30722,    243,    162,    143,    174,  30598,
                 313,  20787,    953,   3848,    275,  16125,    630,  29897,  29871,  31681,
                 313,   6194,    953,  29877,   2397,    393,    756,    967,   1914,   5993,  29897, }, },
+        { "Hello",                  { 1,    15043 }, },
+        { " Hello",                 { 1,    29871,  15043 }, },
+        { "  Hello",                { 1,    259,    15043 }, },
+        { "   Hello",               { 1,    1678,   15043 }, },
+        { "    Hello",              { 1,    268,    15043 }, },
+        { "    Hello\n    Hello",   { 1,    268,    15043,  13,     1678,   15043 }, },
     };
 
     return _k_tests;
@@ -106,7 +114,8 @@ int main(int argc, char **argv) {
 
         if (!correct) {
             fprintf(stderr, "%s : failed test:    '%s'\n", __func__, test_kv.first.c_str());
-            fprintf(stderr, "%s : detokenized to: '%s'\n", __func__, unescape_whitespace(ctx, test_kv.second).c_str());
+            fprintf(stderr, "%s : detokenized to: '%s' instead of '%s'\n", __func__,
+                unescape_whitespace(ctx, res).c_str(), unescape_whitespace(ctx, test_kv.second).c_str());
             fprintf(stderr, "%s : expected tokens: ", __func__);
             for (const auto & t : test_kv.second) {
                 fprintf(stderr, "%6d, ", t);
index d8db7cd96eaa45a8563e17c07bb1cec50a2aec2e..993d17f1833d30779671defd16deb3aa6b502fe9 100644 (file)
 #include <locale>
 
 static std::string escape_whitespace(const std::string& text) {
-    std::string result;
-    bool escaping = false;
-    result += "\xe2\x96\x81";
+    std::string result = "\xe2\x96\x81";
     for (size_t offs = 0; offs < text.length(); ++offs) {
         if (text[offs] == ' ') {
-            if (!escaping) {
-                result += "\xe2\x96\x81";
-                escaping = true;
-            }
-        }
-        else {
-            escaping = false;
+            result += "\xe2\x96\x81";
+        } else {
             result += text[offs];
         }
     }