From: Ravindra Marella
Date: Sun, 18 Jun 2023 07:37:09 +0000 (+0530)
Subject: starcoder : add support for starchat special tokens (#246)
X-Git-Tag: upstream/0.0.1642~1409
X-Git-Url: https://git.djapps.eu/?a=commitdiff_plain;h=e456108433017d5586b35fd36ce781b4c3aed631;p=pkg%2Fggml%2Fsources%2Fggml

starcoder : add support for starchat special tokens (#246)

* starcoder : add support for starchat special tokens

* examples : fix `gpt_tokenize()` for special tokens
---

diff --git a/examples/common.cpp b/examples/common.cpp
index db90742d..cf1769bd 100644
--- a/examples/common.cpp
+++ b/examples/common.cpp
@@ -232,37 +232,53 @@ std::wstring convert_to_wstring(const std::string & input) {
     return converter.from_bytes(input);
 }
 
+void gpt_split_words(std::string str, std::vector<std::string>& words) {
+    const std::string pattern = R"('s|'t|'re|'ve|'m|'ll|'d| ?[[:alpha:]]+| ?[[:digit:]]+| ?[^\s[:alpha:][:digit:]]+|\s+(?!\S)|\s+)";
+    const std::regex re(pattern);
+    std::smatch m;
+
+    while (std::regex_search(str, m, re)) {
+        for (auto x : m) {
+            words.push_back(x);
+        }
+        str = m.suffix();
+    }
+}
+
 std::vector<gpt_vocab::id> gpt_tokenize(const gpt_vocab & vocab, const std::string & text) {
     std::vector<std::string> words;
 
     // first split the text into words
     {
         std::string str = text;
-        std::string pat = R"('s|'t|'re|'ve|'m|'ll|'d| ?[[:alpha:]]+| ?[[:digit:]]+| ?[^\s[:alpha:][:digit:]]+|\s+(?!\S)|\s+)";
 
         // Generate the subpattern from the special_tokens vector if it's not empty
         if (!vocab.special_tokens.empty()) {
+            const std::regex escape(R"([\[\\\^\$\.\|\?\*\+\(\)\{\}])");
            std::string special_tokens_subpattern;
             for (const auto & token : vocab.special_tokens) {
                 if (!special_tokens_subpattern.empty()) {
                     special_tokens_subpattern += "|";
                 }
-                special_tokens_subpattern += token;
+                special_tokens_subpattern += std::regex_replace(token, escape, R"(\$&)");
             }
 
-            // Modify the regex pattern with the generated special tokens subpattern
-            pat = special_tokens_subpattern + "|" + pat;
-        }
-
-        std::regex re(pat);
-        std::smatch m;
-
-        while (std::regex_search(str, m, re)) {
-            for (auto x : m) {
-                words.push_back(x);
+            std::regex re(special_tokens_subpattern);
+            std::smatch m;
+            // Split the text by special tokens.
+            while (std::regex_search(str, m, re)) {
+                // Split the substrings in-between special tokens into words.
+                gpt_split_words(m.prefix(), words);
+                // Add matched special tokens as words.
+                for (auto x : m) {
+                    words.push_back(x);
+                }
+                str = m.suffix();
             }
-            str = m.suffix();
+            // Remaining text without special tokens will be handled below.
         }
+
+        gpt_split_words(str, words);
     }
 
     // find the longest token that forms each word in words:
diff --git a/examples/common.h b/examples/common.h
index 0381802e..0431d5a8 100644
--- a/examples/common.h
+++ b/examples/common.h
@@ -66,6 +66,8 @@ std::string convert_to_utf8(const std::wstring & input);
 
 std::wstring convert_to_wstring(const std::string & input);
 
+void gpt_split_words(std::string str, std::vector<std::string>& words);
+
 // split text into tokens
 //
 // ref: https://github.com/openai/gpt-2/blob/a74da5d99abaaba920de8131d64da2862a8f213b/src/encoder.py#L53
diff --git a/examples/starcoder/main.cpp b/examples/starcoder/main.cpp
index 67e50782..de3b8a50 100644
--- a/examples/starcoder/main.cpp
+++ b/examples/starcoder/main.cpp
@@ -139,6 +139,18 @@ bool starcoder_model_load(const std::string & fname, starcoder_model & model, gp
 
             // if (i < 10) fprintf(stderr, "%.s: vocab[%d] = '%s'\n", __func__, i, word.c_str());
         }
+
+        // Add StarChat special tokens.
+        for (const std::string & token : {
+                "<|system|>",
+                "<|user|>",
+                "<|assistant|>",
+                "<|end|>",
+            }) {
+            if (vocab.token_to_id.find(token) != vocab.token_to_id.end()) {
+                vocab.add_special_token(token);
+            }
+        }
     }
 
     // for the big tensors, we have the option to store the data in 16-bit floats or quantized
@@ -781,6 +793,15 @@ int main(int argc, char ** argv) {
     }
     printf("\n\n");
 
+    // Handle StarChat "<|end|>" token.
+    gpt_vocab::id starchat_end_token = -1;
+    {
+        const auto it = vocab.token_to_id.find("<|end|>");
+        if (it != vocab.token_to_id.end()) {
+            starchat_end_token = it->second;
+        }
+    }
+
     // submit the input prompt token-by-token
     // this reduces the memory usage during inference, at the cost of a bit of speed at the beginning
     std::vector<gpt_vocab::id> embd;
@@ -850,6 +871,10 @@ int main(int argc, char ** argv) {
         else if (embd.back() == 0) { //TODO: this is only for starcoder
             break;
         }
+        // Handle StarChat "<|end|>" token.
+        else if (embd.back() == starchat_end_token) {
+            break;
+        }
     }
 
     // report timing
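
A minimal standalone sketch (not part of the patch) of the splitting scheme gpt_tokenize() uses above: each special token is regex-escaped, the escaped tokens are joined into a single alternation, and that pattern is matched first; the text in-between the matches is what would be handed to gpt_split_words(). The token set and input string here are made up for illustration.

// special_token_split.cpp - build with: g++ -std=c++11 special_token_split.cpp
#include <iostream>
#include <regex>
#include <string>
#include <vector>

int main() {
    // Hypothetical inputs; the tokens mirror the StarChat set registered above.
    const std::vector<std::string> special_tokens = { "<|user|>", "<|end|>" };
    const std::string text = "<|user|>Hello world<|end|>";

    // Escape regex metacharacters in each token before joining with '|'.
    // In the replacement string, $& is the matched character, so R"(\$&)"
    // turns e.g. '|' into '\|'.
    const std::regex escape(R"([\[\\\^\$\.\|\?\*\+\(\)\{\}])");
    std::string subpattern;
    for (const auto & token : special_tokens) {
        if (!subpattern.empty()) {
            subpattern += "|";
        }
        subpattern += std::regex_replace(token, escape, R"(\$&)");
    }
    // subpattern is now: <\|user\|>|<\|end\|>
    // Only the joining '|' remains a live alternation operator.

    std::string str = text;
    const std::regex re(subpattern);
    std::smatch m;
    while (std::regex_search(str, m, re)) {
        std::cout << "between: '" << m.prefix() << "'\n"; // would go to gpt_split_words()
        std::cout << "special: '" << m.str()    << "'\n"; // kept as a single word
        str = m.suffix();
    }
    std::cout << "rest:    '" << str << "'\n"; // remainder, also word-split
    return 0;
}

Without the escaping step, the '|' characters inside a token such as "<|end|>" would themselves be parsed as alternation operators, so the pattern would match a lone '<', '>', or the bare word "end" anywhere in the text instead of the full token.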