#include <string>
#include <thread>
#include <vector>
+#include <regex>
#define USE_FLASH_ATTN
//#define USE_FLASH_FF
return true;
}
+// split text into tokens
+//
+// ref: https://github.com/openai/gpt-2/blob/a74da5d99abaaba920de8131d64da2862a8f213b/src/encoder.py#L53
+//
+// Regex (Python):
+// r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
+//
+// Regex (C++):
+// R"('s|'t|'re|'ve|'m|'ll|'d| ?[[:alpha:]]+| ?[[:digit:]]+| ?[^\s[:alpha:][:digit:]]+|\s+(?!\S)|\s+)"
+//
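+// NOTE: std::regex has no \p{L}/\p{N}, so the C++ pattern approximates them with
+// POSIX classes, which under the default locale match only ASCII. Multi-byte
+// UTF-8 characters are therefore picked up by the "other" alternation instead
+// of being treated as letters or digits.
+//
+// e.g. "it's 2 pm" splits into ["it", "'s", " 2", " pm"]
+//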
+static std::vector<whisper_vocab::id> tokenize(const whisper_vocab & vocab, const std::string & text) {
+ std::vector<std::string> words;
+
+ // first split the text into words
+ {
+ std::string str = text;
+ std::string pat = R"('s|'t|'re|'ve|'m|'ll|'d| ?[[:alpha:]]+| ?[[:digit:]]+| ?[^\s[:alpha:][:digit:]]+|\s+(?!\S)|\s+)";
+
+ std::regex re(pat);
+ std::smatch m;
+
+        while (std::regex_search(str, m, re)) {
+            // the pattern has no capture groups, so the only sub-match is the full match
+            words.push_back(m.str());
+            str = m.suffix();
+        }
+ }
+
+    // then encode each word by greedily matching the longest vocab tokens:
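+    // e.g. with a (hypothetical) vocab containing "ab" and "cd" but no longer
+    // entries, the word "abcd" is encoded as ["ab", "cd"]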
+ std::vector<whisper_vocab::id> tokens;
+ for (const auto & word : words) {
+        if (word.empty()) continue;
+
+        int i = 0;
+        const int n = (int) word.size();
+        while (i < n) {
+            int j = n;
+            bool found = false;
+
+            // greedily match the longest vocab entry starting at position i
+            while (j > i) {
+                const auto it = vocab.token_to_id.find(word.substr(i, j - i));
+                if (it != vocab.token_to_id.end()) {
+                    tokens.push_back(it->second);
+                    i = j;
+                    found = true;
+                    break;
+                }
+                --j;
+            }
+
+            // fall back to a single character only when no entry matched;
+            // testing j == i here would also trigger right after a successful
+            // match (i was just set to j) and consume an extra character
+            if (!found) {
+                const auto sub = word.substr(i, 1);
+                const auto it = vocab.token_to_id.find(sub);
+                if (it != vocab.token_to_id.end()) {
+                    tokens.push_back(it->second);
+                } else {
+                    fprintf(stderr, "%s: unknown token '%s'\n", __func__, sub.c_str());
+                }
+                ++i;
+            }
+        }
+ }
+
+ return tokens;
+}
+
//
// interface implementation
//
return res;
}
+int whisper_tokenize(struct whisper_context * ctx, const char * text, whisper_token * tokens, int n_max_tokens) {
+ const auto res = tokenize(ctx->vocab, text);
+
+    if ((int) res.size() > n_max_tokens) {
+        fprintf(stderr, "%s: too many resulting tokens: %d (max %d)\n", __func__, (int) res.size(), n_max_tokens);
+        return -1;
+    }
+
+    for (int i = 0; i < (int) res.size(); i++) {
+        tokens[i] = res[i];
+    }
+
+    return (int) res.size();
+}
+
int whisper_lang_id(const char * lang) {
if (!g_lang.count(lang)) {
fprintf(stderr, "%s: unknown language '%s'\n", __func__, lang);
WHISPER_API whisper_token_data whisper_sample_best(struct whisper_context * ctx);
WHISPER_API whisper_token_data whisper_sample_timestamp(struct whisper_context * ctx, bool is_initial);
+    // Convert the provided text into tokens.
+    // The tokens pointer must be large enough to hold the resulting tokens.
+    // Returns the number of tokens on success, no more than n_max_tokens.
+    // Returns -1 on failure.
+    // TODO: not sure if correct
+ WHISPER_API int whisper_tokenize(
+ struct whisper_context * ctx,
+ const char * text,
+ whisper_token * tokens,
+ int n_max_tokens);
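+    //
+    // Example (a sketch; assumes an initialized context `ctx`):
+    //
+    //     whisper_token tokens[64];
+    //     const int n = whisper_tokenize(ctx, "hello world", tokens, 64);
+    //     if (n < 0) {
+    //         // the text produced more than 64 tokens
+    //     }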
+
// Return the id of the specified language, returns -1 if not found
WHISPER_API int whisper_lang_id(const char * lang);