]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
starcoder : add support for starchat special tokens (#246)
authorRavindra Marella <redacted>
Sun, 18 Jun 2023 07:37:09 +0000 (13:07 +0530)
committerGitHub <redacted>
Sun, 18 Jun 2023 07:37:09 +0000 (10:37 +0300)
* starcoder : add support for starchat special tokens

* examples : fix `gpt_tokenize()` for special tokens

examples/common.cpp
examples/common.h
examples/starcoder/main.cpp

index db90742d07703085513cab06655d1495475fa166..cf1769bdbaa57591fec07de3832d6fe4b79a9411 100644 (file)
@@ -232,37 +232,53 @@ std::wstring convert_to_wstring(const std::string & input) {
     return converter.from_bytes(input);
 }
 
+void gpt_split_words(std::string str, std::vector<std::string>& words) {
+    const std::string pattern = R"('s|'t|'re|'ve|'m|'ll|'d| ?[[:alpha:]]+| ?[[:digit:]]+| ?[^\s[:alpha:][:digit:]]+|\s+(?!\S)|\s+)";
+    const std::regex re(pattern);
+    std::smatch m;
+
+    while (std::regex_search(str, m, re)) {
+        for (auto x : m) {
+            words.push_back(x);
+        }
+        str = m.suffix();
+    }
+}
+
 std::vector<gpt_vocab::id> gpt_tokenize(const gpt_vocab & vocab, const std::string & text) {
     std::vector<std::string> words;
 
     // first split the text into words
     {
         std::string str = text;
-        std::string pat = R"('s|'t|'re|'ve|'m|'ll|'d| ?[[:alpha:]]+| ?[[:digit:]]+| ?[^\s[:alpha:][:digit:]]+|\s+(?!\S)|\s+)";
 
         // Generate the subpattern from the special_tokens vector if it's not empty
         if (!vocab.special_tokens.empty()) {
+            const std::regex escape(R"([\[\\\^\$\.\|\?\*\+\(\)\{\}])");
             std::string special_tokens_subpattern;
             for (const auto & token : vocab.special_tokens) {
                 if (!special_tokens_subpattern.empty()) {
                     special_tokens_subpattern += "|";
                 }
-                special_tokens_subpattern += token;
+                special_tokens_subpattern += std::regex_replace(token, escape, R"(\$&)");
             }
 
-            // Modify the regex pattern with the generated special tokens subpattern
-            pat = special_tokens_subpattern + "|" + pat;
-        }
-
-        std::regex re(pat);
-        std::smatch m;
-
-        while (std::regex_search(str, m, re)) {
-            for (auto x : m) {
-                words.push_back(x);
+            std::regex re(special_tokens_subpattern);
+            std::smatch m;
+            // Split the text by special tokens.
+            while (std::regex_search(str, m, re)) {
+                // Split the substrings in-between special tokens into words.
+                gpt_split_words(m.prefix(), words);
+                // Add matched special tokens as words.
+                for (auto x : m) {
+                    words.push_back(x);
+                }
+                str = m.suffix();
             }
-            str = m.suffix();
+            // Remaining text without special tokens will be handled below.
         }
+
+        gpt_split_words(str, words);
     }
 
     // find the longest token that forms each word in words:
index 0381802e69cd179e8111441b8e60c3ccc5d3a297..0431d5a876d7aaffd306359f702213a7f813e5ec 100644 (file)
@@ -66,6 +66,8 @@ std::string convert_to_utf8(const std::wstring & input);
 
 std::wstring convert_to_wstring(const std::string & input);
 
+void gpt_split_words(std::string str, std::vector<std::string>& words);
+
 // split text into tokens
 //
 // ref: https://github.com/openai/gpt-2/blob/a74da5d99abaaba920de8131d64da2862a8f213b/src/encoder.py#L53
index 67e507824c6c6c1d04b0e150c5fd894cea4e3610..de3b8a50109d186969be6ce3b29b272a3a827e88 100644 (file)
@@ -139,6 +139,18 @@ bool starcoder_model_load(const std::string & fname, starcoder_model & model, gp
 
             // if (i < 10) fprintf(stderr, "%.s: vocab[%d] = '%s'\n", __func__, i, word.c_str());
         }
+
+        // Add StarChat special tokens.
+        for (const std::string & token : {
+                "<|system|>",
+                "<|user|>",
+                "<|assistant|>",
+                "<|end|>",
+            }) {
+            if (vocab.token_to_id.find(token) != vocab.token_to_id.end()) {
+                vocab.add_special_token(token);
+            }
+        }
     }
 
     // for the big tensors, we have the option to store the data in 16-bit floats or quantized
@@ -781,6 +793,15 @@ int main(int argc, char ** argv) {
     }
     printf("\n\n");
 
+    // Handle StarChat "<|end|>" token.
+    gpt_vocab::id starchat_end_token = -1;
+    {
+        const auto it = vocab.token_to_id.find("<|end|>");
+        if (it != vocab.token_to_id.end()) {
+            starchat_end_token = it->second;
+        }
+    }
+
     // submit the input prompt token-by-token
     // this reduces the memory usage during inference, at the cost of a bit of speed at the beginning
     std::vector<gpt_vocab::id> embd;
@@ -850,6 +871,10 @@ int main(int argc, char ** argv) {
         else if (embd.back() == 0) { //TODO: this is only for starcoder
             break;
         }
+        // Handle StarChat "<|end|>" token.
+        else if (embd.back() == starchat_end_token) {
+            break;
+        }
     }
 
     // report timing