]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
tokenizer : special token handling (#3538)
authorstaviq <redacted>
Tue, 17 Oct 2023 15:11:01 +0000 (17:11 +0200)
committerGitHub <redacted>
Tue, 17 Oct 2023 15:11:01 +0000 (18:11 +0300)
* Rewrite special token handling from #1931

* shorten param name, add st verification by type

* use offsets instead of copy by substr

* formatting, remove copying iterator on delete

* llama : normalize code-style

* swift fix

* print pfx/sfx if verb, main: split pfx input sfx

* dont add space when using special tokens

* minor : comment + spacing

---------

Co-authored-by: Georgi Gerganov <redacted>
common/common.cpp
common/common.h
common/train.cpp
examples/batched.swift/Sources/main.swift
examples/main/main.cpp
llama.cpp
llama.h

index 9c4f7df204673ea1deb5e0b8ed9974851893dfb1..3e4b8a8cbdf7990edde1d9dd597c785e089b3795 100644 (file)
@@ -879,21 +879,23 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
 std::vector<llama_token> llama_tokenize(
   const struct llama_context * ctx,
            const std::string & text,
-                        bool   add_bos) {
-    return llama_tokenize(llama_get_model(ctx), text, add_bos);
+                        bool   add_bos,
+                        bool   special) {
+    return llama_tokenize(llama_get_model(ctx), text, add_bos, special);
 }
 
 std::vector<llama_token> llama_tokenize(
     const struct llama_model * model,
            const std::string & text,
-                        bool   add_bos) {
+                        bool   add_bos,
+                        bool   special) {
     // upper limit for the number of tokens
     int n_tokens = text.length() + add_bos;
     std::vector<llama_token> result(n_tokens);
-    n_tokens = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_bos);
+    n_tokens = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_bos, special);
     if (n_tokens < 0) {
         result.resize(-n_tokens);
-        int check = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_bos);
+        int check = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_bos, special);
         GGML_ASSERT(check == -n_tokens);
     } else {
         result.resize(n_tokens);
index 36fd4416649718553359f22209b2ab3c73a44398..08c6032315e875cf4090f5b3b8d8597ce4c47230 100644 (file)
@@ -137,12 +137,14 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
 std::vector<llama_token> llama_tokenize(
   const struct llama_context * ctx,
            const std::string & text,
-                        bool   add_bos);
+                        bool   add_bos,
+                        bool   special = false);
 
 std::vector<llama_token> llama_tokenize(
     const struct llama_model * model,
            const std::string & text,
-                        bool   add_bos);
+                        bool   add_bos,
+                        bool   special = false);
 
 // tokenizes a token into a piece
 // should work similar to Python's `tokenizer.id_to_piece`
index 35a4cf9e6cae39b16c7c991760b9ec98812d02a8..972eaefe00f05b807cefb8c3c115d513734e38b5 100644 (file)
@@ -863,7 +863,7 @@ size_t tokenize_file(
             (int) buf.size(),
             out_tokens.data(),
             (int) out_tokens.size(),
-            false);
+            false, false);
         if (n_tokens < 0) {
             out_tokens.resize(-n_tokens);
             n_tokens = llama_tokenize(
@@ -872,7 +872,7 @@ size_t tokenize_file(
                 (int) buf.size(),
                 out_tokens.data(),
                 (int) out_tokens.size(),
-                false);
+                false, false);
         }
         if (n_tokens >= 0) {
             out_tokens.resize(n_tokens);
@@ -966,7 +966,7 @@ size_t tokenize_file(
                     (int) buf_sample.size(),
                     tok_sample.data(),
                     (int) tok_sample.size(),
-                    false);
+                    false, false);
                 if (n_tokens < 0) {
                     tok_sample.resize(-n_tokens);
                     n_tokens = llama_tokenize(llama_get_model(lctx),
@@ -974,7 +974,7 @@ size_t tokenize_file(
                         (int) buf_sample.size(),
                         tok_sample.data(),
                         (int) tok_sample.size(),
-                        false);
+                        false, false);
                     GGML_ASSERT(n_tokens >= 0);
                 }
                 GGML_ASSERT(n_tokens <= (int) tok_sample.size());
index 938f30512ca6a8ec96f1666b04dea878ed2f7041..05d1bb9d00068faddf2816bf58f721409a68b8a6 100644 (file)
@@ -209,7 +209,7 @@ llama_print_timings(context)
 private func tokenize(text: String, add_bos: Bool) -> [llama_token] {
     let n_tokens = text.count + (add_bos ? 1 : 0)
     let tokens = UnsafeMutablePointer<llama_token>.allocate(capacity: n_tokens)
-    let tokenCount = llama_tokenize(model, text, Int32(text.count), tokens, Int32(n_tokens), add_bos)
+    let tokenCount = llama_tokenize(model, text, Int32(text.count), tokens, Int32(n_tokens), add_bos, /*special tokens*/ false)
     var swiftTokens: [llama_token] = []
     for i in 0 ..< tokenCount {
         swiftTokens.append(tokens[Int(i)])
index 55f73356fb89a1cd154aae18b44ebeb57ec4abc6..a5fb65548ff4f7003b14975a54b21d8291f2e955 100644 (file)
@@ -238,7 +238,7 @@ int main(int argc, char ** argv) {
 
     if (params.interactive_first || params.instruct || !params.prompt.empty() || session_tokens.empty()) {
         LOG("tokenize the prompt\n");
-        embd_inp = ::llama_tokenize(ctx, params.prompt, add_bos);
+        embd_inp = ::llama_tokenize(ctx, params.prompt, add_bos, true);
     } else {
         LOG("use session tokens\n");
         embd_inp = session_tokens;
@@ -260,10 +260,10 @@ int main(int argc, char ** argv) {
     if (ctx_guidance) {
         LOG("cfg_negative_prompt: \"%s\"\n", log_tostr(sparams.cfg_negative_prompt));
 
-        guidance_inp = ::llama_tokenize(ctx_guidance, sparams.cfg_negative_prompt, add_bos);
+        guidance_inp = ::llama_tokenize(ctx_guidance, sparams.cfg_negative_prompt, add_bos, true);
         LOG("guidance_inp tokenized: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_guidance, guidance_inp));
 
-        std::vector<llama_token> original_inp = ::llama_tokenize(ctx, params.prompt, add_bos);
+        std::vector<llama_token> original_inp = ::llama_tokenize(ctx, params.prompt, add_bos, true);
         LOG("original_inp tokenized: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, original_inp));
 
         original_prompt_len = original_inp.size();
@@ -320,8 +320,8 @@ int main(int argc, char ** argv) {
     }
 
     // prefix & suffix for instruct mode
-    const auto inp_pfx = ::llama_tokenize(ctx, "\n\n### Instruction:\n\n", add_bos);
-    const auto inp_sfx = ::llama_tokenize(ctx, "\n\n### Response:\n\n",    false);
+    const auto inp_pfx = ::llama_tokenize(ctx, "\n\n### Instruction:\n\n", add_bos, true);
+    const auto inp_sfx = ::llama_tokenize(ctx, "\n\n### Response:\n\n",    false,   true);
 
     LOG("inp_pfx: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, inp_pfx));
     LOG("inp_sfx: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, inp_sfx));
@@ -383,6 +383,12 @@ int main(int argc, char ** argv) {
         if (!params.antiprompt.empty()) {
             for (const auto & antiprompt : params.antiprompt) {
                 LOG_TEE("Reverse prompt: '%s'\n", antiprompt.c_str());
+                if (params.verbose_prompt) {
+                    auto tmp = ::llama_tokenize(ctx, antiprompt, false, true);
+                    for (int i = 0; i < (int) tmp.size(); i++) {
+                        LOG_TEE("%6d -> '%s'\n", tmp[i], llama_token_to_piece(ctx, tmp[i]).c_str());
+                    }
+                }
             }
         }
 
@@ -392,10 +398,22 @@ int main(int argc, char ** argv) {
 
         if (!params.input_prefix.empty()) {
             LOG_TEE("Input prefix: '%s'\n", params.input_prefix.c_str());
+            if (params.verbose_prompt) {
+                auto tmp = ::llama_tokenize(ctx, params.input_prefix, true, true);
+                for (int i = 0; i < (int) tmp.size(); i++) {
+                    LOG_TEE("%6d -> '%s'\n", tmp[i], llama_token_to_piece(ctx, tmp[i]).c_str());
+                }
+            }
         }
 
         if (!params.input_suffix.empty()) {
             LOG_TEE("Input suffix: '%s'\n", params.input_suffix.c_str());
+            if (params.verbose_prompt) {
+                auto tmp = ::llama_tokenize(ctx, params.input_suffix, false, true);
+                for (int i = 0; i < (int) tmp.size(); i++) {
+                    LOG_TEE("%6d -> '%s'\n", tmp[i], llama_token_to_piece(ctx, tmp[i]).c_str());
+                }
+            }
         }
     }
     LOG_TEE("sampling: repeat_last_n = %d, repeat_penalty = %f, presence_penalty = %f, frequency_penalty = %f, top_k = %d, tfs_z = %f, top_p = %f, typical_p = %f, temp = %f, mirostat = %d, mirostat_lr = %f, mirostat_ent = %f\n",
@@ -717,7 +735,7 @@ int main(int argc, char ** argv) {
                 if (params.interactive) {
                     if (!params.antiprompt.empty()) {
                         // tokenize and inject first reverse prompt
-                        const auto first_antiprompt = ::llama_tokenize(ctx, params.antiprompt.front(), false);
+                        const auto first_antiprompt = ::llama_tokenize(ctx, params.antiprompt.front(), false, true);
                         embd_inp.insert(embd_inp.end(), first_antiprompt.begin(), first_antiprompt.end());
                         is_antiprompt = true;
                     }
@@ -744,8 +762,7 @@ int main(int argc, char ** argv) {
                 std::string buffer;
                 if (!params.input_prefix.empty()) {
                     LOG("appending input prefix: '%s'\n", params.input_prefix.c_str());
-                    buffer += params.input_prefix;
-                    printf("%s", buffer.c_str());
+                    printf("%s", params.input_prefix.c_str());
                 }
 
                 // color user input only
@@ -767,7 +784,6 @@ int main(int argc, char ** argv) {
                     // append input suffix if any
                     if (!params.input_suffix.empty()) {
                         LOG("appending input suffix: '%s'\n", params.input_suffix.c_str());
-                        buffer += params.input_suffix;
                         printf("%s", params.input_suffix.c_str());
                     }
 
@@ -782,10 +798,14 @@ int main(int argc, char ** argv) {
                         embd_inp.insert(embd_inp.end(), inp_pfx.begin(), inp_pfx.end());
                     }
 
-                    const auto line_inp = ::llama_tokenize(ctx, buffer, false);
+                    const auto line_pfx = ::llama_tokenize(ctx, params.input_prefix, false, true);
+                    const auto line_inp = ::llama_tokenize(ctx, buffer, false, false);
+                    const auto line_sfx = ::llama_tokenize(ctx, params.input_suffix, false, true);
                     LOG("input tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, line_inp));
 
+                    embd_inp.insert(embd_inp.end(), line_pfx.begin(), line_pfx.end());
                     embd_inp.insert(embd_inp.end(), line_inp.begin(), line_inp.end());
+                    embd_inp.insert(embd_inp.end(), line_sfx.begin(), line_sfx.end());
 
                     // instruct mode: insert response suffix
                     if (params.instruct) {
index 5329bd828a12575fed2c990ca31327af5beeb711..82b7638ae7ce1af6b007ec17db725ab0b7161aa9 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
@@ -75,6 +75,7 @@
 #include <thread>
 #include <unordered_map>
 #include <set>
+#include <forward_list>
 
 #if defined(_MSC_VER)
 #pragma warning(disable: 4244 4267) // possible loss of data
@@ -1183,6 +1184,8 @@ struct llama_vocab {
     std::unordered_map<token, id> token_to_id;
     std::vector<token_data>       id_to_token;
 
+    std::unordered_map<token, id> special_tokens_cache;
+
     std::map<std::pair<std::string, std::string>, int> bpe_ranks;
 
     // default LLaMA special tokens
@@ -2125,7 +2128,7 @@ static void llm_load_hparams(
 }
 
 // TODO: This should probably be in llama.h
-static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab & vocab, std::string raw_text, bool bos);
+static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab & vocab, std::string raw_text, bool bos, bool special = false);
 static llama_token llama_byte_to_token(const llama_vocab & vocab, uint8_t ch);
 
 static void llm_load_vocab(
@@ -2241,6 +2244,101 @@ static void llm_load_vocab(
     GGUF_GET_KEY(ctx, vocab.special_unk_id, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_TOKENIZER_UNK_ID));
     GGUF_GET_KEY(ctx, vocab.special_sep_id, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_TOKENIZER_SEP_ID));
     GGUF_GET_KEY(ctx, vocab.special_pad_id, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_TOKENIZER_PAD_ID));
+
+    // build special tokens cache
+    {
+        // TODO: It is unclear (to me) at this point, whether special tokes are guaranteed to be of a deterministic type,
+        //  and will always be correctly labeled in 'added_tokens.json' etc.
+        // The assumption is, since special tokens aren't meant to be exposed to end user, they are designed
+        //  to be unmatchable by the tokenizer, therefore tokens from the vocab, which are unmatchable by the tokenizer
+        //  are special tokens.
+        // From testing, this appears to corelate 1:1 with special tokens.
+        //
+
+        // Counting special tokens and verifying in only one direction
+        //  is sufficient to detect difference in those two sets.
+        //
+        uint32_t special_tokens_count_by_type = 0;
+        uint32_t special_tokens_count_from_verification = 0;
+
+        bool special_tokens_definition_mismatch = false;
+
+        for (const auto & t : vocab.token_to_id) {
+            const auto & token = t.first;
+            const auto & id    = t.second;
+
+            // Count all non-normal tokens in the vocab while iterating
+            if (vocab.id_to_token[id].type != LLAMA_TOKEN_TYPE_NORMAL) {
+                special_tokens_count_by_type++;
+            }
+
+            // Skip single character tokens
+            if (token.length() > 1) {
+                bool is_tokenizable = false;
+
+                // Split token string representation in two, in all possible ways
+                //  and check if both halves can be matched to a valid token
+                for (unsigned i = 1; i < token.length();) {
+                    const auto left  = token.substr(0, i);
+                    const auto right = token.substr(i);
+
+                    // check if we didnt partition in the middle of a utf sequence
+                    auto utf = utf8_len(left.at(left.length() - 1));
+
+                    if (utf == 1) {
+                        if (vocab.token_to_id.find(left)  != vocab.token_to_id.end() &&
+                            vocab.token_to_id.find(right) != vocab.token_to_id.end() ) {
+                            is_tokenizable = true;
+                            break;
+                        }
+                        i++;
+                    } else {
+                        // skip over the rest of multibyte utf sequence
+                        i += utf - 1;
+                    }
+                }
+
+                if (!is_tokenizable) {
+                    // Some tokens are multibyte, but they are utf sequences with equivalent text length of 1
+                    //  it's faster to re-filter them here, since there are way less candidates now
+
+                    // Calculate a total "utf" length of a token string representation
+                    size_t utf8_str_len = 0;
+                    for (unsigned i = 0; i < token.length();) {
+                        utf8_str_len++;
+                        i += utf8_len(token.at(i));
+                    }
+
+                    // And skip the ones which are one character
+                    if (utf8_str_len > 1) {
+                        // At this point what we have left are special tokens only
+                        vocab.special_tokens_cache[token] = id;
+
+                        // Count manually found special tokens
+                        special_tokens_count_from_verification++;
+
+                        // If this manually found special token is not marked as such, flag a mismatch
+                        if (vocab.id_to_token[id].type == LLAMA_TOKEN_TYPE_NORMAL) {
+                            special_tokens_definition_mismatch = true;
+                        }
+                    }
+                }
+            }
+        }
+
+        if (special_tokens_definition_mismatch || special_tokens_count_from_verification != special_tokens_count_by_type) {
+            fprintf(stderr, "%s: warning: Mismatch in special tokens definition ( %u/%zu vs %u/%zu ).\n",
+                __func__,
+                special_tokens_count_from_verification, vocab.id_to_token.size(),
+                special_tokens_count_by_type, vocab.id_to_token.size()
+            );
+        } else {
+            fprintf(stderr, "%s: Special tokens definition check successful ( %u/%zu ).\n",
+                __func__,
+                special_tokens_count_from_verification, vocab.id_to_token.size()
+            );
+        }
+    }
 }
 
 static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
@@ -6464,7 +6562,137 @@ private:
     llm_bigram_bpe::queue work_queue;
 };
 
-static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab & vocab, std::string raw_text, bool bos) {
+typedef enum FRAGMENT_BUFFER_VARIANT_TYPE{
+    FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN,
+    FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT
+} FRAGMENT_BUFFER_VARIANT_TYPE;
+
+struct fragment_buffer_variant{
+    fragment_buffer_variant(llama_vocab::id _token)
+    :
+        type(FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN),
+        token(_token),
+        raw_text(_dummy),
+        offset(0),
+        length(0){}
+    fragment_buffer_variant(const std::string & _raw_text, int64_t _offset, int64_t _length)
+    :
+        type(FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT),
+        token((llama_vocab::id)-1),
+        raw_text(_raw_text),
+        offset(_offset),
+        length(_length){
+            GGML_ASSERT( _offset >= 0 );
+            GGML_ASSERT( _length >= 1 );
+            GGML_ASSERT( offset + length <= raw_text.length() );
+        }
+
+    const FRAGMENT_BUFFER_VARIANT_TYPE type;
+    const llama_vocab::id token;
+    const std::string _dummy;
+    const std::string & raw_text;
+    const uint64_t offset;
+    const uint64_t length;
+};
+
+// #define PRETOKENIZERDEBUG
+
+static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list<fragment_buffer_variant> & buffer)
+{
+    // for each special token
+    for (const auto & st: vocab.special_tokens_cache) {
+        const auto & special_token = st.first;
+        const auto & special_id    = st.second;
+
+        // for each text fragment
+        std::forward_list<fragment_buffer_variant>::iterator it = buffer.begin();
+        while (it != buffer.end()) {
+            auto & fragment = (*it);
+
+            // if a fragment is text ( not yet processed )
+            if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
+                auto * raw_text = &(fragment.raw_text);
+
+                auto raw_text_base_offset = fragment.offset;
+                auto raw_text_base_length = fragment.length;
+
+                // loop over the text
+                while (true) {
+                    // find the first occurence of a given special token in this fragment
+                    //  passing offset argument only limit the "search area" but match coordinates
+                    //  are still relative to the source full raw_text
+                    auto match = raw_text->find(special_token, raw_text_base_offset);
+
+                    // no occurences found, stop processing this fragment for a given special token
+                    if (match == std::string::npos) break;
+
+                    // check if match is within bounds of offset <-> length
+                    if (match + special_token.length() > raw_text_base_offset + raw_text_base_length) break;
+
+#ifdef PRETOKENIZERDEBUG
+                    fprintf(stderr, "FF: (%ld %ld %ld) '%s'\n", raw_text->length(), raw_text_base_offset, raw_text_base_length, raw_text->substr(raw_text_base_offset, raw_text_base_length).c_str());
+#endif
+                    auto source = std::distance(buffer.begin(), it);
+
+                    // if match is further than base offset
+                    //  then we have some text to the left of it
+                    if (match > raw_text_base_offset) {
+                        // left
+                        const int64_t left_reminder_offset = raw_text_base_offset + 0;
+                        const int64_t left_reminder_length = match - raw_text_base_offset;
+                        buffer.emplace_after(it, (*raw_text), left_reminder_offset, left_reminder_length);
+
+#ifdef PRETOKENIZERDEBUG
+                        fprintf(stderr, "FL: (%ld %ld) '%s'\n", left_reminder_offset, left_reminder_length, raw_text->substr(left_reminder_offset, left_reminder_length).c_str());
+#endif
+                        it++;
+                    }
+
+                    // special token
+                    buffer.emplace_after(it, special_id);
+                    it++;
+
+                    // right
+                    if (match + special_token.length() < raw_text_base_offset + raw_text_base_length) {
+                        const int64_t right_reminder_offset = match + special_token.length();
+                        const int64_t right_reminder_length = raw_text_base_length - ((match - raw_text_base_offset) + special_token.length());
+                        buffer.emplace_after(it, (*raw_text), right_reminder_offset, right_reminder_length);
+
+#ifdef PRETOKENIZERDEBUG
+                        fprintf(stderr, "FR: (%ld %ld) '%s'\n", right_reminder_offset, right_reminder_length, raw_text->substr(right_reminder_offset, right_reminder_length).c_str());
+#endif
+
+                        it++;
+
+                        if (source == 0) {
+                            buffer.erase_after(buffer.before_begin());
+                        } else {
+                            buffer.erase_after(std::next(buffer.begin(), (source-1)));
+                        }
+
+                        // repeat for the right side
+                        raw_text_base_offset = right_reminder_offset;
+                        raw_text_base_length = right_reminder_length;
+
+#ifdef PRETOKENIZERDEBUG
+                        fprintf(stderr, "RR: (%ld %ld) '%s'\n", raw_text_base_offset, raw_text_base_length, raw_text->substr(raw_text_base_offset, raw_text_base_length).c_str());
+#endif
+                    } else {
+                        if (source == 0) {
+                            buffer.erase_after(buffer.before_begin());
+                        } else {
+                            buffer.erase_after(std::next(buffer.begin(), (source-1)));
+                        }
+                        break;
+                    }
+                }
+            }
+            it++;
+        }
+    }
+}
+
+static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab & vocab, std::string raw_text, bool bos, bool special) {
     std::vector<llama_vocab::id> output;
 
     // OG tokenizer behavior:
@@ -6480,20 +6708,58 @@ static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab &
         return output;
     }
 
+    std::forward_list<fragment_buffer_variant> fragment_buffer;
+    fragment_buffer.emplace_front( raw_text, 0, raw_text.length() );
+
+    if (special) tokenizer_st_partition( vocab, fragment_buffer );
+
     switch (vocab.type) {
         case LLAMA_VOCAB_TYPE_SPM:
             {
-                // without adding this leading whitespace, we do not get the same results as the original tokenizer
-                raw_text = " " + raw_text;
+                for (const auto & fragment: fragment_buffer)
+                {
+                    if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT)
+                    {
+                        // without adding this leading whitespace, we do not get the same results as the original tokenizer
 
-                llm_tokenizer_spm tokenizer(vocab);
-                llama_escape_whitespace(raw_text);
-                tokenizer.tokenize(raw_text, output);
+                        // TODO: It's likely possible to get rid of this string copy entirely
+                        //  by modifying llm_tokenizer_x to operate with string offsets like pre-tokenizer
+                        //  and passing 'add space prefix' as bool argument
+                        //
+                        auto raw_text = (special ? "" : " ") + fragment.raw_text.substr(fragment.offset, fragment.length);
+
+#ifdef PRETOKENIZERDEBUG
+                        fprintf(stderr,"TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str());
+#endif
+                        llm_tokenizer_spm tokenizer(vocab);
+                        llama_escape_whitespace(raw_text);
+                        tokenizer.tokenize(raw_text, output);
+                    }
+                    else // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
+                    {
+                        output.push_back(fragment.token);
+                    }
+                }
             } break;
         case LLAMA_VOCAB_TYPE_BPE:
             {
-                llm_tokenizer_bpe tokenizer(vocab);
-                tokenizer.tokenize(raw_text, output);
+                for (const auto & fragment: fragment_buffer)
+                {
+                    if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT)
+                    {
+                        auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length);
+
+#ifdef PRETOKENIZERDEBUG
+                        fprintf(stderr,"TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str());
+#endif
+                        llm_tokenizer_bpe tokenizer(vocab);
+                        tokenizer.tokenize(raw_text, output);
+                    }
+                    else // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
+                    {
+                        output.push_back(fragment.token);
+                    }
+                }
             } break;
     }
 
@@ -9407,15 +9673,15 @@ llama_token llama_token_eot(const struct llama_context * ctx) {
     return ctx->model.vocab.special_eot_id;
 }
 
-
 int llama_tokenize(
     const struct llama_model * model,
                   const char * text,
                          int   text_len,
                  llama_token * tokens,
                          int   n_max_tokens,
-                        bool   add_bos) {
-    auto res = llama_tokenize_internal(model->vocab, std::string(text, text_len), add_bos);
+                        bool   add_bos,
+                        bool   special) {
+    auto res = llama_tokenize_internal(model->vocab, std::string(text, text_len), add_bos, special);
 
     if (n_max_tokens < (int) res.size()) {
         // LLAMA_LOG_ERROR("%s: too many tokens\n", __func__);
diff --git a/llama.h b/llama.h
index a78015adab30c37179a87249684df66963d62204..b13f23123390775ddb2cc3ff89fca3425a1fe3ea 100644 (file)
--- a/llama.h
+++ b/llama.h
@@ -511,17 +511,20 @@ extern "C" {
     // Tokenization
     //
 
-    // Convert the provided text into tokens.
-    // The tokens pointer must be large enough to hold the resulting tokens.
-    // Returns the number of tokens on success, no more than n_max_tokens
-    // Returns a negative number on failure - the number of tokens that would have been returned
+    /// @details Convert the provided text into tokens.
+    /// @param tokens The tokens pointer must be large enough to hold the resulting tokens.
+    /// @return Returns the number of tokens on success, no more than n_max_tokens
+    /// @return Returns a negative number on failure - the number of tokens that would have been returned
+    /// @param special Allow tokenizing special and/or control tokens which otherwise are not exposed and treated as plaintext.
+    ///                Does not insert a leading space.
     LLAMA_API int llama_tokenize(
         const struct llama_model * model,
                       const char * text,
                              int   text_len,
                      llama_token * tokens,
                              int   n_max_tokens,
-                            bool   add_bos);
+                            bool   add_bos,
+                            bool   special);
 
     // Token Id -> Piece.
     // Uses the vocabulary in the provided context.