]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
llama : more tokenizer fixes (#2810)
authorGeorgi Gerganov <redacted>
Sun, 27 Aug 2023 11:19:19 +0000 (14:19 +0300)
committerGitHub <redacted>
Sun, 27 Aug 2023 11:19:19 +0000 (14:19 +0300)
* tests : write a Python tokenizer test (wip)

* llama : prefix input text for tokenization with whitespace

* llama : distinguish pieces from decoded text + fix detokenization

* common : add comments

* examples : no longer manually add leading space when tokenizing

* tests : use Python to generate tokenizer tests for C++

* tests : add option to tokenize text files

ggml-ci

* tests : add test-tokenizer-1.py

* llama.cpp : fix LF token

* hellaswag : move the concat space for clarity

* tests : add falcon tests (py + cpp, currently do not pass Unicode)

ggml-ci

* common : temporary separate llama_detokenize calls for SPM and BPE

---------

Co-authored-by: klosax <redacted>
20 files changed:
common/common.cpp
common/common.h
examples/beam_search/beam_search.cpp
examples/embd-input/embd-input-lib.cpp
examples/embedding/embedding.cpp
examples/main/main.cpp
examples/perplexity/perplexity.cpp
examples/save-load-state/save-load-state.cpp
examples/server/server.cpp
examples/simple/simple.cpp
examples/train-text-from-scratch/train-text-from-scratch.cpp
llama.cpp
llama.h
tests/CMakeLists.txt
tests/test-tokenizer-0-falcon.cpp [new file with mode: 0644]
tests/test-tokenizer-0-falcon.py [new file with mode: 0644]
tests/test-tokenizer-0-llama.cpp [new file with mode: 0644]
tests/test-tokenizer-0-llama.py [new file with mode: 0644]
tests/test-tokenizer-0.cpp [deleted file]
tests/test-tokenizer-1.cpp

index ff19ec4e50f60485b66ec940b9d50937c3e105fe..0d91a6a35acaaaa99ef4562f08186e97cd31fd19 100644 (file)
@@ -733,12 +733,12 @@ std::vector<llama_token> llama_tokenize(
     return result;
 }
 
-std::string llama_token_to_str(const struct llama_context * ctx, llama_token token) {
+std::string llama_token_to_piece(const struct llama_context * ctx, llama_token token) {
     std::vector<char> result(8, 0);
-    const int n_tokens = llama_token_to_str(ctx, token, result.data(), result.size());
+    const int n_tokens = llama_token_to_piece(ctx, token, result.data(), result.size());
     if (n_tokens < 0) {
         result.resize(-n_tokens);
-        int check = llama_token_to_str(ctx, token, result.data(), result.size());
+        int check = llama_token_to_piece(ctx, token, result.data(), result.size());
         GGML_ASSERT(check == -n_tokens);
     } else {
         result.resize(n_tokens);
@@ -746,3 +746,36 @@ std::string llama_token_to_str(const struct llama_context * ctx, llama_token tok
 
     return std::string(result.data(), result.size());
 }
+
+std::string llama_detokenize_spm(llama_context * ctx, const std::vector<llama_token> & tokens) {
+    const llama_token bos_id = llama_token_bos(ctx);
+
+    std::string piece;
+    std::string result;
+
+    for (size_t i = 0; i < tokens.size(); ++i) {
+        piece = llama_token_to_piece(ctx, tokens[i]);
+
+        // remove the leading space of the first non-BOS token
+        if (((tokens[0] == bos_id && i == 1) || (tokens[0] != bos_id && i == 0)) && piece[0] == ' ') {
+            piece = piece.substr(1);
+        }
+
+        result += piece;
+    }
+
+    return result;
+}
+
+std::string llama_detokenize_bpe(llama_context * ctx, const std::vector<llama_token> & tokens) {
+    std::string piece;
+    std::string result;
+
+    for (size_t i = 0; i < tokens.size(); ++i) {
+        piece = llama_token_to_piece(ctx, tokens[i]);
+
+        result += piece;
+    }
+
+    return result;
+}
index ce61265f8c12472427ac83be17653f9cb980d189..97fda2be78b5188d717dbff1b3defc2eb8a9ab85 100644 (file)
@@ -116,11 +116,31 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
 // Vocab utils
 //
 
+// tokenizes a string into a vector of tokens
+// should work similar to Python's `tokenizer.encode`
 std::vector<llama_token> llama_tokenize(
         struct llama_context * ctx,
            const std::string & text,
                         bool   add_bos);
 
-std::string llama_token_to_str(
+// tokenizes a token into a piece
+// should work similar to Python's `tokenizer.id_to_piece`
+std::string llama_token_to_piece(
         const struct llama_context * ctx,
                        llama_token   token);
+
+// TODO: these should be moved in llama.h C-style API under single `llama_detokenize` function
+//       that takes into account the tokenizer type and decides how to handle the leading space
+//
+// detokenizes a vector of tokens into a string
+// should work similar to Python's `tokenizer.decode`
+// removes the leading space from the first non-BOS token
+std::string llama_detokenize_spm(
+                         llama_context * ctx,
+        const std::vector<llama_token> & tokens);
+
+// detokenizes a vector of tokens into a string
+// should work similar to Python's `tokenizer.decode`
+std::string llama_detokenize_bpe(
+                         llama_context * ctx,
+        const std::vector<llama_token> & tokens);
index 1c04fabc21b3d82aea98df46d72364b6f5b074c6..42c7c72542321c5c5beaf7e84db96a5b99264e2c 100644 (file)
@@ -35,7 +35,7 @@ struct ostream_beam_view {
 std::ostream& operator<<(std::ostream& os, const ostream_beam_view & obv) {
     os << "p(" << obv.beam_view.p << ") eob(" << std::boolalpha << obv.beam_view.eob << ") tokens(";
     for (size_t i = 0 ; i < obv.beam_view.n_tokens ; ++i) {
-        os << llama_token_to_str(obv.ctx, obv.beam_view.tokens[i]);
+        os << llama_token_to_piece(obv.ctx, obv.beam_view.tokens[i]);
     }
     return os << ')';
 }
@@ -156,7 +156,7 @@ int main(int argc, char ** argv)
 
     for( auto id : tokens_list )
     {
-        std::cout << llama_token_to_str(ctx, id);
+        std::cout << llama_token_to_piece(ctx, id);
     }
     std::cout << std::flush;
 
@@ -175,7 +175,7 @@ int main(int argc, char ** argv)
 
     std::cout << "\n\n";
     for (llama_token const token_id : callback_data.response) {
-        std::cout << llama_token_to_str(ctx,token_id);
+        std::cout << llama_token_to_piece(ctx,token_id);
     }
     std::cout << std::endl;
 
index 8a6ad882e8fa8c24ccd8a81a1b3eb52993e9969f..036bdb3987f34752ed7de593ba7d562355dc9d03 100644 (file)
@@ -214,7 +214,7 @@ const char * sampling(struct MyModel * mymodel) {
     if (id == llama_token_eos(ctx)) {
         ret = "</s>";
     } else {
-        ret = llama_token_to_str(ctx, id);
+        ret = llama_token_to_piece(ctx, id);
     }
     eval_id(mymodel, id);
     return ret.c_str();
index 38395c75b0b5bc3ecbde8fd9413082476a7d4859..93d583b5ce15170358e648dc4251d55ca179feb6 100644 (file)
@@ -56,9 +56,6 @@ int main(int argc, char ** argv) {
 
     int n_past = 0;
 
-    // Add a space in front of the first character to match OG llama tokenizer behavior
-    params.prompt.insert(0, 1, ' ');
-
     // tokenize the prompt
     auto embd_inp = ::llama_tokenize(ctx, params.prompt, true);
 
@@ -67,7 +64,7 @@ int main(int argc, char ** argv) {
         fprintf(stderr, "%s: prompt: '%s'\n", __func__, params.prompt.c_str());
         fprintf(stderr, "%s: number of tokens in prompt = %zu\n", __func__, embd_inp.size());
         for (int i = 0; i < (int) embd_inp.size(); i++) {
-            fprintf(stderr, "%6d -> '%s'\n", embd_inp[i], llama_token_to_str(ctx, embd_inp[i]).c_str());
+            fprintf(stderr, "%6d -> '%s'\n", embd_inp[i], llama_token_to_piece(ctx, embd_inp[i]).c_str());
         }
         fprintf(stderr, "\n");
     }
index 11d7a7e4f4fe1a0853f66afb44d105a830af2af0..3ce57f436b89340f19aad550b1c94ebaa08cde70 100644 (file)
@@ -195,11 +195,6 @@ int main(int argc, char ** argv) {
     // tokenize the prompt
     std::vector<llama_token> embd_inp;
 
-    if (llama_vocab_type(ctx) == LLAMA_VOCAB_TYPE_SPM) {
-        // Add a space in front of the first character to match OG llama tokenizer behavior
-        params.prompt.insert(0, 1, ' ');
-    }
-
     if (params.interactive_first || params.instruct || !params.prompt.empty() || session_tokens.empty()) {
         embd_inp = ::llama_tokenize(ctx, params.prompt, add_bos);
     } else {
@@ -216,7 +211,6 @@ int main(int argc, char ** argv) {
     int guidance_offset = 0;
     int original_prompt_len = 0;
     if (ctx_guidance) {
-        params.cfg_negative_prompt.insert(0, 1, ' ');
         guidance_inp = ::llama_tokenize(ctx_guidance, params.cfg_negative_prompt, add_bos);
 
         std::vector<llama_token> original_inp = ::llama_tokenize(ctx, params.prompt, add_bos);
@@ -285,7 +279,7 @@ int main(int argc, char ** argv) {
         fprintf(stderr, "%s: prompt: '%s'\n", __func__, params.prompt.c_str());
         fprintf(stderr, "%s: number of tokens in prompt = %zu\n", __func__, embd_inp.size());
         for (int i = 0; i < (int) embd_inp.size(); i++) {
-            fprintf(stderr, "%6d -> '%s'\n", embd_inp[i], llama_token_to_str(ctx, embd_inp[i]).c_str());
+            fprintf(stderr, "%6d -> '%s'\n", embd_inp[i], llama_token_to_piece(ctx, embd_inp[i]).c_str());
         }
 
         if (ctx_guidance) {
@@ -293,14 +287,14 @@ int main(int argc, char ** argv) {
             fprintf(stderr, "%s: negative prompt: '%s'\n", __func__, params.cfg_negative_prompt.c_str());
             fprintf(stderr, "%s: number of tokens in negative prompt = %zu\n", __func__, guidance_inp.size());
             for (int i = 0; i < (int) guidance_inp.size(); i++) {
-                fprintf(stderr, "%6d -> '%s'\n", guidance_inp[i], llama_token_to_str(ctx, guidance_inp[i]).c_str());
+                fprintf(stderr, "%6d -> '%s'\n", guidance_inp[i], llama_token_to_piece(ctx, guidance_inp[i]).c_str());
             }
         }
 
         if (params.n_keep > 0) {
         fprintf(stderr, "%s: static prompt based on n_keep: '", __func__);
             for (int i = 0; i < params.n_keep; i++) {
-                fprintf(stderr, "%s", llama_token_to_str(ctx, embd_inp[i]).c_str());
+                fprintf(stderr, "%s", llama_token_to_piece(ctx, embd_inp[i]).c_str());
             }
             fprintf(stderr, "'\n");
         }
@@ -456,7 +450,7 @@ int main(int argc, char ** argv) {
                 //printf("\n---\n");
                 //printf("resetting: '");
                 //for (int i = 0; i < (int) embd.size(); i++) {
-                //    printf("%s", llama_token_to_str(ctx, embd[i]));
+                //    printf("%s", llama_token_to_piece(ctx, embd[i]));
                 //}
                 //printf("'\n");
                 //printf("\n---\n");
@@ -509,7 +503,7 @@ int main(int argc, char ** argv) {
                     input_size = embd_guidance.size();
                     //fprintf(stderr, "\n---------------------\n");
                     //for (int i = 0; i < (int) embd_guidance.size(); i++) {
-                        //fprintf(stderr, "%s", llama_token_to_str(ctx, embd_guidance[i]));
+                        //fprintf(stderr, "%s", llama_token_to_piece(ctx, embd_guidance[i]));
                     //}
                     //fprintf(stderr, "\n---------------------\n");
                 } else {
@@ -673,7 +667,7 @@ int main(int argc, char ** argv) {
         // display text
         if (input_echo) {
             for (auto id : embd) {
-                printf("%s", llama_token_to_str(ctx, id).c_str());
+                printf("%s", llama_token_to_piece(ctx, id).c_str());
             }
             fflush(stdout);
         }
@@ -689,7 +683,7 @@ int main(int argc, char ** argv) {
             if (params.antiprompt.size()) {
                 std::string last_output;
                 for (auto id : last_n_tokens) {
-                    last_output += llama_token_to_str(ctx, id);
+                    last_output += llama_token_to_piece(ctx, id);
                 }
 
                 is_antiprompt = false;
index fd89852d6deddb9dc0e88e5b52d7eb7be7297ca6..b596d062613d7cd66c57053f9c1524f3b2cd8baf 100644 (file)
@@ -392,7 +392,7 @@ void hellaswag_score(llama_context * ctx, const gpt_params & params) {
         hs_data[i].context = prompt_lines[idx*6];
         hs_data[i].gold_ending_idx = std::stoi( prompt_lines[idx*6+1] );
         for (size_t j=0; j < 4; j++) {
-            hs_data[i].ending[j] = " " + prompt_lines[idx*6+2+j];
+            hs_data[i].ending[j] = prompt_lines[idx*6+2+j];
         }
 
         // Delete the selected random example from the prompt
@@ -417,7 +417,7 @@ void hellaswag_score(llama_context * ctx, const gpt_params & params) {
         size_t context_size = context_embd.size();
 
         for (int i = 0; i < 4; ++i) {
-            ending_tokens[i] = ::llama_tokenize(ctx, hs_data[task_idx].context + hs_data[task_idx].ending[i], add_bos);
+            ending_tokens[i] = ::llama_tokenize(ctx, hs_data[task_idx].context + " " + hs_data[task_idx].ending[i], add_bos);
             for (int k = 0; k < int(context_size); ++k) {
                 if (ending_tokens[i][k] != context_embd[k]) {
                     fprintf(stderr, "Oops: ending %d of task %d differs from context at position %d\n",i,int(task_idx),k);
index 3db61b7541171ed409465b31768e204b5a8aab1a..573bc4ef988a69e60b6d51de27ded2c1291d5629 100644 (file)
@@ -87,7 +87,7 @@ int main(int argc, char ** argv) {
         }
         llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
         auto next_token = llama_sample_token(ctx, &candidates_p);
-        auto next_token_str = llama_token_to_str(ctx, next_token);
+        auto next_token_str = llama_token_to_piece(ctx, next_token);
         last_n_tokens_data.push_back(next_token);
 
         printf("%s", next_token_str.c_str());
@@ -147,7 +147,7 @@ int main(int argc, char ** argv) {
         }
         llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
         auto next_token = llama_sample_token(ctx2, &candidates_p);
-        auto next_token_str = llama_token_to_str(ctx2, next_token);
+        auto next_token_str = llama_token_to_piece(ctx2, next_token);
         last_n_tokens_data.push_back(next_token);
 
         printf("%s", next_token_str.c_str());
index a4b4d641859365ef4572bd9eceda0926bb183ec7..89a3311f5432977639b90af608bdb58a4e287bf6 100644 (file)
@@ -94,7 +94,7 @@ static std::string tokens_to_str(llama_context *ctx, Iter begin, Iter end)
     std::string ret;
     for (; begin != end; ++begin)
     {
-        ret += llama_token_to_str(ctx, *begin);
+        ret += llama_token_to_piece(ctx, *begin);
     }
     return ret;
 }
@@ -123,7 +123,7 @@ static void server_log(const char *level, const char *function, int line,
 // format incomplete utf-8 multibyte character for output
 static std::string tokens_to_output_formatted_string(const llama_context *ctx, const llama_token token)
 {
-    std::string out = token == -1 ? "" : llama_token_to_str(ctx, token);
+    std::string out = token == -1 ? "" : llama_token_to_piece(ctx, token);
     // if the size is 1 and first bit is 1, meaning it's a partial character
     //   (size > 1 meaning it's already a known token)
     if (out.size() == 1 && (out[0] & 0x80) == 0x80)
@@ -286,7 +286,6 @@ struct llama_server_context
                     std::vector<llama_token> p;
                     if (first)
                     {
-                        s.insert(0, 1, ' '); // add a space if it's the first
                         p = ::llama_tokenize(ctx, s, add_bos);
                         first = false;
                     }
@@ -309,7 +308,6 @@ struct llama_server_context
         else
         {
             auto s = json_prompt.template get<std::string>();
-            s.insert(0, 1, ' '); // always add a first space
             prompt_tokens = ::llama_tokenize(ctx, s, add_bos);
         }
 
@@ -566,7 +564,7 @@ struct llama_server_context
 
         if (!embd.empty() && embd.back() == llama_token_eos(ctx))
         {
-            // stopping_word = llama_token_to_str(ctx, embd.back());
+            // stopping_word = llama_token_to_piece(ctx, embd.back());
             has_next_token = false;
             stopped_eos = true;
             LOG_VERBOSE("eos token found", {});
@@ -613,7 +611,7 @@ struct llama_server_context
     {
         const completion_token_output token_with_probs = nextToken();
 
-        const std::string token_text = token_with_probs.tok == -1 ? "" : llama_token_to_str(ctx, token_with_probs.tok);
+        const std::string token_text = token_with_probs.tok == -1 ? "" : llama_token_to_piece(ctx, token_with_probs.tok);
         generated_text += token_text;
 
         if (params.n_probs > 0)
@@ -1254,7 +1252,7 @@ void beam_search_callback(void * callback_data, llama_beams_state beams_state) {
 
 struct token_translator {
     llama_context * ctx;
-    std::string operator()(llama_token tok) const { return llama_token_to_str(ctx, tok); }
+    std::string operator()(llama_token tok) const { return llama_token_to_piece(ctx, tok); }
     std::string operator()(completion_token_output cto) const { return (*this)(cto.tok); }
 };
 
@@ -1364,7 +1362,7 @@ int main(int argc, char **argv)
 
                 while (llama.has_next_token) {
                     const completion_token_output token_with_probs = llama.doCompletion();
-                    const std::string token_text = token_with_probs.tok == -1 ? "" : llama_token_to_str(llama.ctx, token_with_probs.tok);
+                    const std::string token_text = token_with_probs.tok == -1 ? "" : llama_token_to_piece(llama.ctx, token_with_probs.tok);
 
                     stop_pos = llama.findStoppingStrings(llama.generated_text,
                         token_text.size(), STOP_FULL);
@@ -1395,7 +1393,7 @@ int main(int argc, char **argv)
                     if (token_with_probs.tok == -1 || llama.multibyte_pending > 0) {
                         continue;
                     }
-                    const std::string token_text = llama_token_to_str(llama.ctx, token_with_probs.tok);
+                    const std::string token_text = llama_token_to_piece(llama.ctx, token_with_probs.tok);
 
                     size_t pos = std::min(sent_count, llama.generated_text.size());
 
index 132f7fbf912bb042699237285a93796f18af0d6d..4ee85faca9f4a93c3a4b8b5e12e7a9ad8a04c365 100644 (file)
@@ -63,7 +63,7 @@ int main(int argc, char ** argv) {
     fprintf(stderr, "\n\n");
 
     for (auto id : tokens_list) {
-        fprintf(stderr, "%s", llama_token_to_str(ctx, id).c_str());
+        fprintf(stderr, "%s", llama_token_to_piece(ctx, id).c_str());
     }
 
     fflush(stderr);
@@ -112,7 +112,7 @@ int main(int argc, char ** argv) {
         }
 
         // print the new token :
-        printf("%s", llama_token_to_str(ctx, new_token_id).c_str());
+        printf("%s", llama_token_to_piece(ctx, new_token_id).c_str());
         fflush(stdout);
 
         // push this new token for next evaluation
index 79b117df72fd33efe5ccabdc8f8e7c93f4894f87..12d153417968b92fb8a688f51ec554b452d68908 100644 (file)
@@ -1964,7 +1964,7 @@ void print_matrix(struct ggml_tensor * probs) {
 
 
 void print_token(struct llama_context * ctx, llama_token token) {
-    printf("%s", llama_token_to_str(ctx, token).c_str());
+    printf("%s", llama_token_to_piece(ctx, token).c_str());
 }
 
 void print_tokens(struct llama_context* ctx, struct ggml_tensor * tokens) {
@@ -2202,7 +2202,7 @@ int tokenize_file(struct llama_context * lctx, const char * filename, std::vecto
         const char * in  = buf.data();
         const char * end = buf.data() + buf.size();
         for (int i = 0; i < (int) out.size(); ++i) {
-            std::string s = llama_token_to_str(lctx, out[i]);
+            std::string s = llama_token_to_piece(lctx, out[i]);
             int len = s.length();
             if (in >= end) {
                 printf("%s: unexpected end of original text.\n", __func__);
index e956c0163901d96fcc50e59f5e863c8d5354f38e..2a8af4ee9d27d0fbb481efccf15ae0152db7eda5 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
@@ -796,12 +796,12 @@ static void llama_nop(struct ggml_tensor * tensor) { // don't offload by default
     (void) tensor;
 }
 
-static std::string llama_token_to_text(const struct llama_context * ctx, llama_token token) {
+static std::string llama_token_to_str(const struct llama_context * ctx, llama_token token) {
     std::vector<char> result(8, 0);
-    const int n_tokens = llama_token_to_str(ctx, token, result.data(), result.size());
+    const int n_tokens = llama_token_to_piece(ctx, token, result.data(), result.size());
     if (n_tokens < 0) {
         result.resize(-n_tokens);
-        int check = llama_token_to_str(ctx, token, result.data(), result.size());
+        int check = llama_token_to_piece(ctx, token, result.data(), result.size());
         GGML_ASSERT(check == -n_tokens);
     } else {
         result.resize(n_tokens);
@@ -1635,7 +1635,8 @@ static void llm_load_hparams(
 }
 
 // TODO: This should probably be in llama.h
-static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab & vocab, const std::string & raw_text, bool bos);
+static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab & vocab, std::string raw_text, bool bos);
+static llama_token llama_byte_to_token(const llama_vocab & vocab, uint8_t ch);
 
 static void llm_load_vocab(
         llama_model_loader & ml,
@@ -1737,7 +1738,11 @@ static void llm_load_vocab(
     }
 
     // determine the newline token: LLaMA "<0x0A>" == 10 == '\n', Falcon 193 == '\n'
-    vocab.linefeed_id = llama_tokenize_internal(vocab, "\n", false)[0];
+    if (vocab.type == LLAMA_VOCAB_TYPE_SPM) {
+        vocab.linefeed_id = llama_byte_to_token(vocab, '\n');
+    } else {
+        vocab.linefeed_id = llama_tokenize_internal(vocab, "\n", false)[0];
+    }
 
     // special tokens
     GGUF_GET_KEY(ctx, vocab.special_bos_id, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_TOKENIZER_BOS_ID));
@@ -3026,10 +3031,8 @@ static llama_token llama_byte_to_token(const llama_vocab & vocab, uint8_t ch) {
     return vocab.token_to_id.at(buf);
 }
 
-static std::string llama_escape_whitespace(const std::string& text) {
-    std::string result = text;
-    replace_all(result, " ", "\xe2\x96\x81");
-    return result;
+static void llama_escape_whitespace(std::string & text) {
+    replace_all(text, " ", "\xe2\x96\x81");
 }
 
 static void llama_unescape_whitespace(std::string & word) {
@@ -3373,22 +3376,31 @@ private:
     llm_bigram_bpe::queue work_queue;
 };
 
-static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab & vocab, const std::string & raw_text, bool bos) {
+static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab & vocab, std::string raw_text, bool bos) {
     std::vector<llama_vocab::id> output;
 
-    if (raw_text.empty()) {
-        return output;
-    }
+    // OG tokenizer behavior:
+    //
+    // tokenizer.encode('', add_bos=True)  returns [1]
+    // tokenizer.encode('', add_bos=False) returns []
 
     if (bos && vocab.special_bos_id != -1) {
         output.push_back(vocab.special_bos_id);
     }
 
+    if (raw_text.empty()) {
+        return output;
+    }
+
     switch (vocab.type) {
         case LLAMA_VOCAB_TYPE_SPM:
             {
+                // without adding this leading whitespace, we do not get the same results as the original tokenizer
+                raw_text = " " + raw_text;
+
                 llm_tokenizer_spm tokenizer(vocab);
-                tokenizer.tokenize(llama_escape_whitespace(raw_text), output);
+                llama_escape_whitespace(raw_text);
+                tokenizer.tokenize(raw_text, output);
             } break;
         case LLAMA_VOCAB_TYPE_BPE:
             {
@@ -4078,16 +4090,16 @@ void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * c
     std::vector<llama_grammar_candidate>                              candidates_grammar;
 
     for (size_t i = 0; i < candidates->size; ++i) {
-        const llama_token id   = candidates->data[i].id;
-        const std::string text = llama_token_to_text(ctx, id);
+        const llama_token id    = candidates->data[i].id;
+        const std::string piece = llama_token_to_str(ctx, id);
         if (id == eos) {
             if (!allow_eos) {
                 candidates->data[i].logit = -INFINITY;
             }
-        } else if (text.empty() || text[0] == 0) {
+        } else if (piece.empty() || piece[0] == 0) {
             candidates->data[i].logit = -INFINITY;
         } else {
-            candidates_decoded.push_back(decode_utf8(text.c_str(), grammar->partial_utf8));
+            candidates_decoded.push_back(decode_utf8(piece.c_str(), grammar->partial_utf8));
             candidates_grammar.push_back({ i, candidates_decoded.back().first.data(), candidates_decoded.back().second });
         }
     }
@@ -4291,10 +4303,10 @@ void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar
         GGML_ASSERT(false);
     }
 
-    const std::string text = llama_token_to_text(ctx, token);
+    const std::string piece = llama_token_to_str(ctx, token);
 
     // Note terminating 0 in decoded string
-    const auto   decoded     = decode_utf8(text.c_str(), grammar->partial_utf8);
+    const auto   decoded     = decode_utf8(piece.c_str(), grammar->partial_utf8);
     const auto & code_points = decoded.first;
     for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
         grammar->stacks = llama_grammar_accept(grammar->rules, grammar->stacks, *it);
@@ -6101,12 +6113,12 @@ int llama_tokenize_with_model(
     return res.size();
 }
 
-int llama_token_to_str(const struct llama_context * ctx, llama_token token, char * buf, int length) {
-    return llama_token_to_str_with_model(&ctx->model, token, buf, length);
+int llama_token_to_piece(const struct llama_context * ctx, llama_token token, char * buf, int length) {
+    return llama_token_to_piece_with_model(&ctx->model, token, buf, length);
 }
 
-// does not write null-terminator to str
-int llama_token_to_str_with_model(const struct llama_model * model, llama_token token, char * buf, int length) {
+// does not write null-terminator to buf
+int llama_token_to_piece_with_model(const struct llama_model * model, llama_token token, char * buf, int length) {
     if (0 <= token && token < llama_model_n_vocab(model)) {
         if (llama_is_normal_token(model->vocab, token)) {
             std::string result = model->vocab.id_to_token[token].text;
diff --git a/llama.h b/llama.h
index b77dd7735fdf094592a6fbf3e0131295508699ce..b084fe23c8fcc84a60bd8e0de87720fc362b59ac 100644 (file)
--- a/llama.h
+++ b/llama.h
@@ -381,15 +381,17 @@ extern "C" {
                              int   n_max_tokens,
                             bool   add_bos);
 
-    // Token Id -> String. Uses the vocabulary in the provided context
-    // Does not write null terminator to the buffer
-    LLAMA_API int llama_token_to_str(
+    // Token Id -> Piece.
+    // Uses the vocabulary in the provided context.
+    // Does not write null terminator to the buffer.
+    // User code is responsible to remove the leading whitespace of the first non-BOS token when decoding multiple tokens.
+    LLAMA_API int llama_token_to_piece(
             const struct llama_context * ctx,
                            llama_token   token,
                                   char * buf,
                                   int    length);
 
-    LLAMA_API int llama_token_to_str_with_model(
+    LLAMA_API int llama_token_to_piece_with_model(
               const struct llama_model * model,
                            llama_token   token,
                                   char * buf,
index 2afaf86b114504bb8d52d18dfd3d6f0d729c0af3..ca1f39d31b08184c53a27573f11ab7141a374c08 100644 (file)
@@ -25,8 +25,10 @@ endfunction()
 llama_build_and_test_executable(test-quantize-fns.cpp)
 llama_build_and_test_executable(test-quantize-perf.cpp)
 llama_build_and_test_executable(test-sampling.cpp)
-llama_build_executable(test-tokenizer-0.cpp)
-llama_test_executable (test-tokenizer-0.llama test-tokenizer-0.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-llama.gguf)
+llama_build_executable(test-tokenizer-0-llama.cpp)
+llama_test_executable (test-tokenizer-0-llama test-tokenizer-0-llama.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-llama.gguf)
+llama_build_executable(test-tokenizer-0-falcon.cpp)
+#llama_test_executable (test-tokenizer-0-falcon test-tokenizer-0-falcon.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-falcon.gguf)
 llama_build_executable(test-tokenizer-1.cpp)
 # test-tokenizer-1 requires a BPE vocab. re-enable when we have one.
 #llama_test_executable (test-tokenizer-1.llama test-tokenizer-1.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-falcon.gguf)
diff --git a/tests/test-tokenizer-0-falcon.cpp b/tests/test-tokenizer-0-falcon.cpp
new file mode 100644 (file)
index 0000000..836fb8a
--- /dev/null
@@ -0,0 +1,178 @@
+#include "llama.h"
+#include "common.h"
+
+#include <cstdio>
+#include <string>
+#include <map>
+#include <vector>
+#include <fstream>
+
+// generate using test-tokenizer-0-falcon.py
+static const std::map<std::string, std::vector<llama_token>> & k_tests() {
+    static std::map<std::string, std::vector<llama_token>> _k_tests = {
+        { ""                      , {  }, },
+        { " "                     , {     204, }, },
+        { "  "                    , {     258, }, },
+        { "   "                   , {     466, }, },
+        { "\t"                    , {     192, }, },
+        { "\n"                    , {     193, }, },
+        { "\t\n"                  , {   19125, }, },
+        { "Hello world"           , {    9856,   1079, }, },
+        { " Hello world"          , {   23090,   1079, }, },
+        { "Hello World"           , {    9856,   2889, }, },
+        { " Hello World"          , {   23090,   2889, }, },
+        { " Hello World!"         , {   23090,   2889,     12, }, },
+        { "Hello, world!"         , {    9856,     23,   1079,     12, }, },
+        { " Hello, world!"        , {   23090,     23,   1079,     12, }, },
+        { " this is πŸ¦™.cpp"        , {     414,    304,   3346,    111,    231,     25,  29247, }, },
+        { "w048 7tuijk dsdfhu"    , {      98,  55866,    204,     34,  16682,   7149,  36190,   6869,  11481, }, },
+        { "Π½Π΅Ρ‰ΠΎ Π½Π° Π‘ΡŠΠ»Π³Π°Ρ€ΡΠΊΠΈ"     , {     150,    133,   6207,    151,    215,    150,    134,   5052,    133,   6279,   5052,    223,    151,    216,  49679,    123,  53110,  47043,   7795, }, },
+        { "αž€αžΆαž“αŸ‹αžαŸ‚αž–αž·αžŸαŸαžŸαž’αžΆαž…αžαž›αž…αŸαž‰"   , {   38154,    206,  38154,    126,  38154,    225,    167,    237,    217,  38154,    221,    167,    237,    208,  38154,    228,  38154,    127,  38154,    237,    167,    237,    207,  38154,    237,  38154,    107,  38154,    126,  38154,    211,  38154,    207,  38154,    233,  38154,    211,    167,    237,    207,  38154,    215, }, },
+        { "πŸš€ (normal) πŸ˜Άβ€πŸŒ«οΈ (multiple emojis concatenated) βœ… (only emoji that has its own token)", {    2571,    232,    206,    204,     19,  11003,     20,   8196,    126,    283,    219,  48778,    116,  13392,    204,     19,  51831,    732,  63209,   1741,   7955,    522,     20,  22438,    211,    204,     19,   7927,  53360,    325,    504,    701,    946,  10930,     20, }, },
+        { "Hello"                 , {    9856, }, },
+        { " Hello"                , {   23090, }, },
+        { "  Hello"               , {     204,  23090, }, },
+        { "   Hello"              , {     258,  23090, }, },
+        { "    Hello"             , {     466,  23090, }, },
+        { "    Hello\n    Hello"  , {     466,  23090,    742,  23090, }, },
+    };
+
+    return _k_tests;
+}
+
+int main(int argc, char **argv) {
+    if (argc < 2) {
+        fprintf(stderr, "Usage: %s vocab-file [text-file]\n", argv[0]);
+        return 1;
+    }
+
+    const std::string fname = argv[1];
+
+    std::string fname_text;
+    if (argc > 2) {
+        fname_text = argv[2];
+    }
+
+    fprintf(stderr, "%s : reading vocab from: '%s'\n", __func__, fname.c_str());
+
+    llama_model * model;
+    llama_context * ctx;
+
+    llama_backend_init(false);
+
+    // load the vocab
+    {
+        auto lparams = llama_context_default_params();
+
+        lparams.vocab_only = true;
+
+        model = llama_load_model_from_file(fname.c_str(), lparams);
+
+        if (model == NULL) {
+            fprintf(stderr, "%s: error: failed to load vocab '%s'\n", __func__, fname.c_str());
+            return 1;
+        }
+
+        ctx = llama_new_context_with_model(model, lparams);
+
+        if (ctx == NULL) {
+            fprintf(stderr, "%s: error: failed to load vocab '%s'\n", __func__, fname.c_str());
+            llama_free_model(model);
+            return 1;
+        }
+    }
+
+    if (llama_vocab_type(ctx) != LLAMA_VOCAB_TYPE_BPE) {
+        fprintf(stderr, "%s : error: vocab type is not SPM\n", __func__);
+        llama_free_model(model);
+        llama_free(ctx);
+        return 2;
+    }
+
+    bool success = true;
+
+    for (const auto & test_kv : k_tests()) {
+        const std::vector<llama_token> res = llama_tokenize(ctx, test_kv.first, false);
+
+        printf("\n");
+        printf("src: '%s'\n", test_kv.first.c_str());
+        printf("res: '%s'\n", llama_detokenize_bpe(ctx, res).c_str());
+        printf("tok: ");
+        for (const auto & tok : res) {
+            printf("%d ", tok);
+        }
+        printf("\n");
+
+        bool correct = res.size() == test_kv.second.size();
+
+        for (int i = 0; i < (int) res.size() && correct; ++i) {
+            if (test_kv.second[i] != res[i]) {
+                correct = false;
+            }
+        }
+
+        if (!correct) {
+            fprintf(stderr, "%s : failed test:    '%s'\n", __func__, test_kv.first.c_str());
+            fprintf(stderr, "%s : detokenized to: '%s' instead of '%s'\n", __func__,
+                llama_detokenize_bpe(ctx, res).c_str(),
+                llama_detokenize_bpe(ctx, test_kv.second).c_str());
+            fprintf(stderr, "%s : expected tokens: ", __func__);
+            for (const auto & t : test_kv.second) {
+                fprintf(stderr, "%6d, ", t);
+            }
+            fprintf(stderr, "\n");
+            fprintf(stderr, "%s : got tokens:      ", __func__);
+            for (const auto & t : res) {
+                fprintf(stderr, "%6d, ", t);
+            }
+            fprintf(stderr, "\n");
+
+            success = false;
+        }
+    }
+
+    if (!fname_text.empty()) {
+        fprintf(stderr, "%s : tokenizing: '%s'\n", __func__, fname_text.c_str());
+
+        std::string text;
+        {
+            std::ifstream ifs(fname_text);
+            if (!ifs) {
+                fprintf(stderr, "%s : error: could not open file '%s'\n", __func__, fname_text.c_str());
+                return 1;
+            }
+            text = std::string(std::istreambuf_iterator<char>(ifs), std::istreambuf_iterator<char>());
+        }
+
+        fprintf(stderr, "%s : text size: %zu\n", __func__, text.size());
+
+        const std::vector<llama_token> res = llama_tokenize(ctx, text, true);
+
+        fprintf(stderr, "%s : tokens: %zu\n", __func__, res.size());
+
+        {
+            const std::string fname_out = fname_text + ".tokcpp";
+
+            std::ofstream ofs(fname_out);
+            if (!ofs) {
+                fprintf(stderr, "%s : error: could not open file '%s'\n", __func__, fname_out.c_str());
+                return 1;
+            }
+
+            for (const auto & tok : res) {
+                ofs << tok << " ";
+            }
+
+            ofs << "\n";
+        }
+
+        fprintf(stderr, "%s : tokens written to '%s'\n", __func__, (fname_text + ".tokcpp").c_str());
+    }
+
+    llama_free_model(model);
+    llama_free(ctx);
+
+    llama_backend_free();
+
+    return success ? 0 : 3;
+}
diff --git a/tests/test-tokenizer-0-falcon.py b/tests/test-tokenizer-0-falcon.py
new file mode 100644 (file)
index 0000000..9c8c1c7
--- /dev/null
@@ -0,0 +1,83 @@
+# tests with BPE tokenizer
+
+import os
+import sys
+import argparse
+
+from transformers import AutoTokenizer
+
+parser = argparse.ArgumentParser()
+parser.add_argument("dir_tokenizer", help="directory containing 'tokenizer.model' file")
+parser.add_argument("--fname-tok",   help="path to a text file to tokenize")
+args = parser.parse_args()
+
+dir_tokenizer = args.dir_tokenizer
+
+tokenizer = AutoTokenizer.from_pretrained(dir_tokenizer)
+
+tests = [
+        "",
+        " ",
+        "  ",
+        "   ",
+        "\t",
+        "\n",
+        "\t\n",
+        "Hello world",
+        " Hello world",
+        "Hello World",
+        " Hello World",
+        " Hello World!",
+        "Hello, world!",
+        " Hello, world!",
+        " this is πŸ¦™.cpp",
+        "w048 7tuijk dsdfhu",
+        "Π½Π΅Ρ‰ΠΎ Π½Π° Π‘ΡŠΠ»Π³Π°Ρ€ΡΠΊΠΈ",
+        "αž€αžΆαž“αŸ‹αžαŸ‚αž–αž·αžŸαŸαžŸαž’αžΆαž…αžαž›αž…αŸαž‰",
+        "πŸš€ (normal) πŸ˜Άβ€πŸŒ«οΈ (multiple emojis concatenated) βœ… (only emoji that has its own token)",
+        "Hello",
+        " Hello",
+        "  Hello",
+        "   Hello",
+        "    Hello",
+        "    Hello\n    Hello",
+    ]
+
+for text in tests:
+    print('text: ', text)
+    print(tokenizer.encode(text))
+    print(tokenizer.decode(tokenizer.encode(text)))
+
+print("\n\ntests for C++:\n")
+for text in tests:
+    res = tokenizer.encode(text)
+
+    k = text.replace('\n', '\\n')
+    k = k.replace('\t', '\\t')
+    k = '"' + k + '"'
+    print("{ %-24s, { " % k, end='')
+    for x in res:
+        print("%7d," % x, end='')
+    print(" }, },")
+
+print(tokenizer.encode('hello'))
+print(tokenizer.encode('world'))
+print(tokenizer.encode(' world'))
+print(tokenizer.encode('hello world'))
+
+fname_tok = args.fname_tok
+if fname_tok:
+    print('tokenizing file: ', fname_tok)
+    fname_out = fname_tok + '.tok'
+    with open(fname_tok, 'r') as f:
+        lines = f.readlines()
+        s = ''.join(lines)
+        res = tokenizer.encode(s)
+        # write to file
+        with open(fname_out, 'w') as f:
+            for x in res:
+                f.write(str(x) + ' ')
+            f.write('\n')
+        print('len(res): ', len(res))
+        print('len(lines): ', len(lines))
+    print('results written to: ', fname_out)
diff --git a/tests/test-tokenizer-0-llama.cpp b/tests/test-tokenizer-0-llama.cpp
new file mode 100644 (file)
index 0000000..8630742
--- /dev/null
@@ -0,0 +1,182 @@
+#include "llama.h"
+#include "common.h"
+
+#include <cstdio>
+#include <string>
+#include <map>
+#include <vector>
+#include <fstream>
+
+// generate using test-tokenizer-0-llama.py
+static const std::map<std::string, std::vector<llama_token>> & k_tests() {
+    static std::map<std::string, std::vector<llama_token>> _k_tests = {
+        { ""                      , {  }, },
+        { " "                     , {     259, }, },
+        { "  "                    , {    1678, }, },
+        { "   "                   , {     268, }, },
+        { "\t"                    , {   29871,     12, }, },
+        { "\n"                    , {   29871,     13, }, },
+        { "\t\n"                  , {   29871,     12,     13, }, },
+        { "Hello world"           , {   15043,   3186, }, },
+        { " Hello world"          , {   29871,  15043,   3186, }, },
+        { "Hello World"           , {   15043,   2787, }, },
+        { " Hello World"          , {   29871,  15043,   2787, }, },
+        { " Hello World!"         , {   29871,  15043,   2787,  29991, }, },
+        { "Hello, world!"         , {   15043,  29892,   3186,  29991, }, },
+        { " Hello, world!"        , {   29871,  15043,  29892,   3186,  29991, }, },
+        { " this is πŸ¦™.cpp"        , {   29871,    445,    338,  29871,    243,    162,    169,    156,  29889,   8223, }, },
+        { "w048 7tuijk dsdfhu"    , {     281,  29900,  29946,  29947,  29871,  29955,   9161,  13535,  18031,   2176,   6905, }, },
+        { "Π½Π΅Ρ‰ΠΎ Π½Π° Π‘ΡŠΠ»Π³Π°Ρ€ΡΠΊΠΈ"     , {    1538,   4851,    665,   1386,  29713,   1305, }, },
+        { "αž€αžΆαž“αŸ‹αžαŸ‚αž–αž·αžŸαŸαžŸαž’αžΆαž…αžαž›αž…αŸαž‰"   , {   29871,  31849,  31324,  31934,    228,    162,    142,    228,    161,    146,    228,    162,    133,    228,    161,    153,    228,    161,    186,  31708,    228,    162,    132,  31708,    228,    161,    165,  31324,    228,    161,    136,    228,    161,    132,    228,    161,    158,    228,    161,    136,    228,    162,    132,    228,    161,    140, }, },
+        { "πŸš€ (normal) πŸ˜Άβ€πŸŒ«οΈ (multiple emojis concatenated) βœ… (only emoji that has its own token)", {   29871,    243,    162,    157,    131,    313,   8945,  29897,  29871,    243,    162,    155,    185,  30722,    243,    162,    143,    174,  30598,    313,  20787,    953,   3848,    275,  16125,    630,  29897,  29871,  31681,    313,   6194,    953,  29877,   2397,    393,    756,    967,   1914,   5993,  29897, }, },
+        { "Hello"                 , {   15043, }, },
+        { " Hello"                , {   29871,  15043, }, },
+        { "  Hello"               , {     259,  15043, }, },
+        { "   Hello"              , {    1678,  15043, }, },
+        { "    Hello"             , {     268,  15043, }, },
+        { "    Hello\n    Hello"  , {     268,  15043,     13,   1678,  15043, }, },
+    };
+
+    return _k_tests;
+}
+
+int main(int argc, char **argv) {
+    if (argc < 2) {
+        fprintf(stderr, "Usage: %s vocab-file [text-file]\n", argv[0]);
+        return 1;
+    }
+
+    const std::string fname = argv[1];
+
+    std::string fname_text;
+    if (argc > 2) {
+        fname_text = argv[2];
+    }
+
+    fprintf(stderr, "%s : reading vocab from: '%s'\n", __func__, fname.c_str());
+
+    llama_model * model;
+    llama_context * ctx;
+
+    llama_backend_init(false);
+
+    // load the vocab
+    {
+        auto lparams = llama_context_default_params();
+
+        lparams.vocab_only = true;
+
+        model = llama_load_model_from_file(fname.c_str(), lparams);
+
+        if (model == NULL) {
+            fprintf(stderr, "%s: error: failed to load vocab '%s'\n", __func__, fname.c_str());
+            return 1;
+        }
+
+        ctx = llama_new_context_with_model(model, lparams);
+
+        if (ctx == NULL) {
+            fprintf(stderr, "%s: error: failed to load vocab '%s'\n", __func__, fname.c_str());
+            llama_free_model(model);
+            return 1;
+        }
+    }
+
+    if (llama_vocab_type(ctx) != LLAMA_VOCAB_TYPE_SPM) {
+        fprintf(stderr, "%s : error: vocab type is not SPM\n", __func__);
+        llama_free_model(model);
+        llama_free(ctx);
+        return 2;
+    }
+
+    bool success = true;
+
+    for (const auto & test_kv : k_tests()) {
+        const std::vector<llama_token> res_bos   = llama_tokenize(ctx, test_kv.first, true);
+        const std::vector<llama_token> res_nobos = llama_tokenize(ctx, test_kv.first, false);
+
+        printf("\n");
+        printf("src: '%s'\n", test_kv.first.c_str());
+        printf("res: '%s'\n", llama_detokenize_spm(ctx, res_bos).c_str());
+        printf("tok: ");
+        for (const auto & tok : res_bos) {
+            printf("%d ", tok);
+        }
+        printf("\n");
+
+        bool correct = res_nobos.size() == test_kv.second.size() && res_bos.size() == res_nobos.size() + 1 && res_bos[0] == 1;
+
+        for (int i = 0; i < (int) res_nobos.size() && correct; ++i) {
+            if (test_kv.second[i] != res_bos[i + 1]) {
+                correct = false;
+            }
+            if (test_kv.second[i] != res_nobos[i]) {
+                correct = false;
+            }
+        }
+
+        if (!correct) {
+            fprintf(stderr, "%s : failed test:    '%s'\n", __func__, test_kv.first.c_str());
+            fprintf(stderr, "%s : detokenized to: '%s' instead of '%s'\n", __func__,
+                llama_detokenize_spm(ctx, res_nobos).c_str(),
+                llama_detokenize_spm(ctx, test_kv.second).c_str());
+            fprintf(stderr, "%s : expected tokens: ", __func__);
+            for (const auto & t : test_kv.second) {
+                fprintf(stderr, "%6d, ", t);
+            }
+            fprintf(stderr, "\n");
+            fprintf(stderr, "%s : got tokens:      ", __func__);
+            for (const auto & t : res_nobos) {
+                fprintf(stderr, "%6d, ", t);
+            }
+            fprintf(stderr, "\n");
+
+            success = false;
+        }
+    }
+
+    if (!fname_text.empty()) {
+        fprintf(stderr, "%s : tokenizing: '%s'\n", __func__, fname_text.c_str());
+
+        std::string text;
+        {
+            std::ifstream ifs(fname_text);
+            if (!ifs) {
+                fprintf(stderr, "%s : error: could not open file '%s'\n", __func__, fname_text.c_str());
+                return 1;
+            }
+            text = std::string(std::istreambuf_iterator<char>(ifs), std::istreambuf_iterator<char>());
+        }
+
+        fprintf(stderr, "%s : text size: %zu\n", __func__, text.size());
+
+        const std::vector<llama_token> res = llama_tokenize(ctx, text, true);
+
+        fprintf(stderr, "%s : tokens: %zu\n", __func__, res.size());
+
+        {
+            const std::string fname_out = fname_text + ".tokcpp";
+
+            std::ofstream ofs(fname_out);
+            if (!ofs) {
+                fprintf(stderr, "%s : error: could not open file '%s'\n", __func__, fname_out.c_str());
+                return 1;
+            }
+
+            for (const auto & tok : res) {
+                ofs << tok << " ";
+            }
+
+            ofs << "\n";
+        }
+
+        fprintf(stderr, "%s : tokens written to '%s'\n", __func__, (fname_text + ".tokcpp").c_str());
+    }
+
+    llama_free_model(model);
+    llama_free(ctx);
+
+    llama_backend_free();
+
+    return success ? 0 : 3;
+}
diff --git a/tests/test-tokenizer-0-llama.py b/tests/test-tokenizer-0-llama.py
new file mode 100644 (file)
index 0000000..bc164ee
--- /dev/null
@@ -0,0 +1,95 @@
+# tests with SPM tokenizer
+
+import os
+import sys
+import argparse
+
+from sentencepiece import SentencePieceProcessor
+
+parser = argparse.ArgumentParser()
+parser.add_argument("dir_tokenizer", help="directory containing 'tokenizer.model' file")
+parser.add_argument("--fname-tok",   help="path to a text file to tokenize")
+args = parser.parse_args()
+
+dir_tokenizer = args.dir_tokenizer
+
+tokenizer = SentencePieceProcessor(dir_tokenizer + '/tokenizer.model')
+
+tests = [
+        "",
+        " ",
+        "  ",
+        "   ",
+        "\t",
+        "\n",
+        "\t\n",
+        "Hello world",
+        " Hello world",
+        "Hello World",
+        " Hello World",
+        " Hello World!",
+        "Hello, world!",
+        " Hello, world!",
+        " this is πŸ¦™.cpp",
+        "w048 7tuijk dsdfhu",
+        "Π½Π΅Ρ‰ΠΎ Π½Π° Π‘ΡŠΠ»Π³Π°Ρ€ΡΠΊΠΈ",
+        "αž€αžΆαž“αŸ‹αžαŸ‚αž–αž·αžŸαŸαžŸαž’αžΆαž…αžαž›αž…αŸαž‰",
+        "πŸš€ (normal) πŸ˜Άβ€πŸŒ«οΈ (multiple emojis concatenated) βœ… (only emoji that has its own token)",
+        "Hello",
+        " Hello",
+        "  Hello",
+        "   Hello",
+        "    Hello",
+        "    Hello\n    Hello",
+    ]
+
+
+for text in tests:
+    print('text: ', text)
+    print('\nwith bos:')
+    print(tokenizer.encode(text, add_bos=True))
+    print(tokenizer.decode(tokenizer.encode(text, add_bos=True)))
+    print('\nwithout bos:')
+    print(tokenizer.encode(text, add_bos=False))
+    print(tokenizer.decode(tokenizer.encode(text, add_bos=False)))
+
+print("'" + tokenizer.id_to_piece(15043) + "'") # '_Hello'
+print("'" + tokenizer.id_to_piece(29871) + "'") # '_'
+print("'" + tokenizer.decode([15043]) + "'")        # 'Hello'
+print("'" + tokenizer.decode([15043, 15043]) + "'") # 'Hello Hello'
+print("'" + tokenizer.decode([29871, 15043]) + "'")               # ' Hello'
+print("'" + tokenizer.decode([29871, 15043, 29871, 15043]) + "'") # ' Hello  Hello'
+
+print("\n\ntests for C++:\n")
+for text in tests:
+    res = tokenizer.encode(text, add_bos=False)
+
+    k = text.replace('\n', '\\n')
+    k = k.replace('\t', '\\t')
+    k = '"' + k + '"'
+    print("{ %-24s, { " % k, end='')
+    for x in res:
+        print("%7d," % x, end='')
+    print(" }, },")
+
+print(tokenizer.encode('hello'))
+print(tokenizer.encode('world'))
+print(tokenizer.encode(' world'))
+print(tokenizer.encode('hello world'))
+
+fname_tok = args.fname_tok
+if fname_tok:
+    print('tokenizing file: ', fname_tok)
+    fname_out = fname_tok + '.tok'
+    with open(fname_tok, 'r') as f:
+        lines = f.readlines()
+        s = ''.join(lines)
+        res = tokenizer.encode(s, add_bos=True)
+        # write to file
+        with open(fname_out, 'w') as f:
+            for x in res:
+                f.write(str(x) + ' ')
+            f.write('\n')
+        print('len(res): ', len(res))
+        print('len(lines): ', len(lines))
+    print('results written to: ', fname_out)
diff --git a/tests/test-tokenizer-0.cpp b/tests/test-tokenizer-0.cpp
deleted file mode 100644 (file)
index 7e9ac91..0000000
+++ /dev/null
@@ -1,141 +0,0 @@
-#include "llama.h"
-#include "common.h"
-
-#include <cstdio>
-#include <string>
-#include <map>
-#include <vector>
-
-static std::string unescape_whitespace(llama_context* ctx, const std::vector<llama_token>& tokens) {
-    std::string result;
-    for (size_t i = 0; i < tokens.size(); ++i) {
-        result += llama_token_to_str(ctx, tokens[i]);
-    }
-    return result;
-}
-
-static const std::map<std::string, std::vector<llama_token>> & k_tests() {
-    static std::map<std::string, std::vector<llama_token>> _k_tests = {
-        { " ",                      {1,    259, }, },
-        { "  ",                     { 1,    1678, }, },
-        { "   ",                    { 1,     268, }, },
-        { "\t",                     { 1,    29871,   12, }, },
-        { "\n",                     { 1,    29871,   13, }, },
-        { "\t\n",                   { 1,    29871,   12,     13, }, },
-        { "Hello world",            { 1,  15043,   3186, }, },
-        { " Hello world",           { 1,  29871,  15043,   3186, }, },
-        { "Hello World",            { 1,  15043,   2787, }, },
-        { " Hello World",           { 1,  29871,  15043,   2787, }, },
-        { " Hello World!",          { 1,  29871,  15043,   2787,  29991, }, },
-        { " this is πŸ¦™.cpp",        { 1,  29871,    445,    338,  29871,    243,    162,    169,    156,  29889,   8223, }, },
-        { "w048 7tuijk dsdfhu",     { 1,    281,  29900,  29946,  29947,  29871,  29955,   9161,  13535,  18031,   2176,   6905, }, },
-        { "Π½Π΅Ρ‰ΠΎ Π½Π° Π‘ΡŠΠ»Π³Π°Ρ€ΡΠΊΠΈ",      { 1,   1538,   4851,    665,   1386,  29713,   1305, }, },
-        { "αž€αžΆαž“αŸ‹αžαŸ‚αž–αž·αžŸαŸαžŸαž’αžΆαž…αžαž›αž…αŸαž‰",   { 1,  29871,  31849,  31324,  31934,    228,    162,    142,    228,    161,
-                                     146,    228,    162,    133,    228,    161,    153,    228,    161,    186,
-                                     31708,    228,    162,    132,  31708,    228,    161,    165,  31324,    228,
-                                     161,    136,    228,    161,    132,    228,    161,    158,    228,    161,
-                                     136,    228,    162,    132,    228,    161,    140, }, },
-        { "πŸš€ (normal) πŸ˜Άβ€πŸŒ«οΈ (multiple emojis concatenated) βœ… (only emoji that has its own token)",
-            { 1,  29871,    243,    162,    157,    131,    313,   8945,  29897,  29871,
-                243,    162,    155,    185,  30722,    243,    162,    143,    174,  30598,
-                313,  20787,    953,   3848,    275,  16125,    630,  29897,  29871,  31681,
-                313,   6194,    953,  29877,   2397,    393,    756,    967,   1914,   5993,  29897, }, },
-        { "Hello",                  { 1,    15043 }, },
-        { " Hello",                 { 1,    29871,  15043 }, },
-        { "  Hello",                { 1,    259,    15043 }, },
-        { "   Hello",               { 1,    1678,   15043 }, },
-        { "    Hello",              { 1,    268,    15043 }, },
-        { "    Hello\n    Hello",   { 1,    268,    15043,  13,     1678,   15043 }, },
-    };
-
-    return _k_tests;
-}
-
-int main(int argc, char **argv) {
-    if (argc < 2) {
-        fprintf(stderr, "Usage: %s <vocab-file>\n", argv[0]);
-        return 1;
-    }
-
-    const std::string fname = argv[1];
-
-    fprintf(stderr, "%s : reading vocab from: '%s'\n", __func__, fname.c_str());
-
-    llama_model * model;
-    llama_context * ctx;
-
-    llama_backend_init(false);
-
-    // load the vocab
-    {
-        auto lparams = llama_context_default_params();
-
-        lparams.vocab_only = true;
-
-        model = llama_load_model_from_file(fname.c_str(), lparams);
-
-        if (model == NULL) {
-            fprintf(stderr, "%s: error: failed to load vocab '%s'\n", __func__, fname.c_str());
-            return 1;
-        }
-
-        ctx = llama_new_context_with_model(model, lparams);
-
-        if (ctx == NULL) {
-            fprintf(stderr, "%s: error: failed to load vocab '%s'\n", __func__, fname.c_str());
-            llama_free_model(model);
-            return 1;
-        }
-    }
-
-    const int n_vocab = llama_n_vocab(ctx);
-
-    if (n_vocab != 32000) {
-        fprintf(stderr, "%s : expected 32000 tokens, got %d\n", __func__, n_vocab);
-        llama_free_model(model);
-        llama_free(ctx);
-        return 2;
-    }
-
-    bool success = true;
-
-    for (const auto & test_kv : k_tests()) {
-        // Add a space in front of the first character to match OG llama tokenizer behavior
-        std::vector<llama_token> res = llama_tokenize(ctx, " " + test_kv.first, true);
-        fprintf(stderr, "%s : '%s' tokenized to '%s'\n",
-            __func__, test_kv.first.c_str(), unescape_whitespace(ctx, res).c_str());
-
-        bool correct = res.size() == test_kv.second.size();
-
-        for (int i = 0; i < (int) res.size() && correct; ++i) {
-            if (res[i] != test_kv.second[i]) {
-                correct = false;
-            }
-        }
-
-        if (!correct) {
-            fprintf(stderr, "%s : failed test:    '%s'\n", __func__, test_kv.first.c_str());
-            fprintf(stderr, "%s : detokenized to: '%s' instead of '%s'\n", __func__,
-                unescape_whitespace(ctx, res).c_str(), unescape_whitespace(ctx, test_kv.second).c_str());
-            fprintf(stderr, "%s : expected tokens: ", __func__);
-            for (const auto & t : test_kv.second) {
-                fprintf(stderr, "%6d, ", t);
-            }
-            fprintf(stderr, "\n");
-            fprintf(stderr, "%s : got tokens:      ", __func__);
-            for (const auto & t : res) {
-                fprintf(stderr, "%6d, ", t);
-            }
-            fprintf(stderr, "\n");
-
-            success = false;
-        }
-    }
-
-    llama_free_model(model);
-    llama_free(ctx);
-
-    llama_backend_free();
-
-    return success ? 0 : 3;
-}
index bd607d12bb1cd8dd9293aa1959fc381749a2153b..ce4f2898ce49abd0a432241fd1ff0d4c95c20b03 100644 (file)
@@ -22,14 +22,6 @@ static std::string escape_whitespace(const std::string& text) {
     return result;
 }
 
-static std::string unescape_whitespace(llama_context * ctx, const std::vector<llama_token> & tokens) {
-    std::string result;
-    for (size_t i = 0; i < tokens.size(); ++i) {
-        result += llama_token_to_str(ctx, tokens[i]);
-    }
-    return result;
-}
-
 int main(int argc, char **argv) {
     if (argc < 2) {
         fprintf(stderr, "Usage: %s <vocab-file>\n", argv[0]);
@@ -72,13 +64,13 @@ int main(int argc, char **argv) {
     const int n_vocab = llama_n_vocab(ctx);
 
     for (int i = 0; i < n_vocab; ++i) {
-        std::string forward = llama_token_to_str(ctx, i);
+        std::string forward = llama_token_to_piece(ctx, i);
         std::vector<llama_token> tokens = llama_tokenize(ctx, forward, false);
         if (tokens.size() == 1) {
             if (i != tokens[0]) {
-                std::string backward = llama_token_to_str(ctx, tokens[0]);
+                std::string backward = llama_token_to_piece(ctx, tokens[0]);
                 fprintf(stderr, "%s : error: token %d is string %s but bpe returns token %d %s\n",
-                    __func__, i, llama_token_to_str(ctx, i).c_str(), tokens[0], backward.c_str());
+                    __func__, i, llama_token_to_piece(ctx, i).c_str(), tokens[0], backward.c_str());
                 return 2;
             }
         }