]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
Fix spm whitespaces (#2806)
authorklosax <redacted>
Sat, 26 Aug 2023 11:45:53 +0000 (13:45 +0200)
committerGitHub <redacted>
Sat, 26 Aug 2023 11:45:53 +0000 (13:45 +0200)
* llama.cpp : fix spm whitespace escaping + clean up

* main.cpp : spm - add whitespace in front of prompt

* test-tokenizer-0.cpp : spm - add whitespace in front of prompt

examples/main/main.cpp
llama.cpp
tests/test-tokenizer-0.cpp

index cb8747c2b74f1728049ad4a6b2f2cd6761a2b199..4665b82fe7f97f654e3e86200a5c8907ce35a4c6 100644 (file)
@@ -189,12 +189,19 @@ int main(int argc, char ** argv) {
         }
     }
 
-    const bool is_spm = llama_vocab_type(ctx) == LLAMA_VOCAB_TYPE_SPM;
+    // Add BOS if SPM tokenizer
+    const bool add_bos = llama_vocab_type(ctx) == LLAMA_VOCAB_TYPE_SPM;
 
     // tokenize the prompt
     std::vector<llama_token> embd_inp;
+
+    if (llama_vocab_type(ctx) == LLAMA_VOCAB_TYPE_SPM) {
+        // Add a space in front of the first character to match OG llama tokenizer behavior
+        params.prompt.insert(0, 1, ' ');
+    }
+
     if (params.interactive_first || params.instruct || !params.prompt.empty() || session_tokens.empty()) {
-        embd_inp = ::llama_tokenize(ctx, params.prompt, is_spm);
+        embd_inp = ::llama_tokenize(ctx, params.prompt, add_bos);
     } else {
         embd_inp = session_tokens;
     }
@@ -210,9 +217,9 @@ int main(int argc, char ** argv) {
     int original_prompt_len = 0;
     if (ctx_guidance) {
         params.cfg_negative_prompt.insert(0, 1, ' ');
-        guidance_inp = ::llama_tokenize(ctx_guidance, params.cfg_negative_prompt, is_spm);
+        guidance_inp = ::llama_tokenize(ctx_guidance, params.cfg_negative_prompt, add_bos);
 
-        std::vector<llama_token> original_inp = ::llama_tokenize(ctx, params.prompt, is_spm);
+        std::vector<llama_token> original_inp = ::llama_tokenize(ctx, params.prompt, add_bos);
         original_prompt_len = original_inp.size();
         guidance_offset = (int)guidance_inp.size() - original_prompt_len;
     }
@@ -259,7 +266,7 @@ int main(int argc, char ** argv) {
     }
 
     // prefix & suffix for instruct mode
-    const auto inp_pfx = ::llama_tokenize(ctx, "\n\n### Instruction:\n\n", is_spm);
+    const auto inp_pfx = ::llama_tokenize(ctx, "\n\n### Instruction:\n\n", add_bos);
     const auto inp_sfx = ::llama_tokenize(ctx, "\n\n### Response:\n\n",    false);
 
     // in instruct mode, we inject a prefix and a suffix to each input by the user
index 7d8b9a0ac485b114227b24f56005e57b5484d518..b0a3b5768f3dd3ca58359cdd9674944d1c9c3682 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
@@ -1635,7 +1635,7 @@ static void llm_load_hparams(
 }
 
 // TODO: This should probably be in llama.h
-static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab & vocab, const std::string & raw_text, bool bos, bool escape);
+static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab & vocab, const std::string & raw_text, bool bos);
 
 static void llm_load_vocab(
         llama_model_loader & ml,
@@ -1737,7 +1737,7 @@ static void llm_load_vocab(
     }
 
     // determine the newline token: LLaMA "<0x0A>" == 10 == '\n', Falcon 193 == '\n'
-    vocab.linefeed_id = llama_tokenize_internal(vocab, "\n", false, false)[0];
+    vocab.linefeed_id = llama_tokenize_internal(vocab, "\n", false)[0];
 
     // special tokens
     GGUF_GET_KEY(ctx, vocab.special_bos_id, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_TOKENIZER_BOS_ID));
@@ -3027,14 +3027,8 @@ static llama_token llama_byte_to_token(const llama_vocab & vocab, uint8_t ch) {
 }
 
 static std::string llama_escape_whitespace(const std::string& text) {
-    std::string result = "\xe2\x96\x81";
-    for (size_t offs = 0; offs < text.length(); ++offs) {
-        if (text[offs] == ' ') {
-            result += "\xe2\x96\x81";
-        } else {
-            result += text[offs];
-        }
-    }
+    std::string result = text;
+    replace_all(result, " ", "\xe2\x96\x81");
     return result;
 }
 
@@ -3219,7 +3213,7 @@ struct llm_bigram_bpe {
 };
 
 struct llm_tokenizer_bpe {
-    llm_tokenizer_bpe(const llama_vocab & vocab, bool g2ws): vocab(vocab) { flag_g2ws = g2ws; }
+    llm_tokenizer_bpe(const llama_vocab & vocab): vocab(vocab) {}
 
     void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) {
         int final_prev_index = -1;
@@ -3371,8 +3365,6 @@ private:
         return words;
     }
 
-    bool flag_g2ws = false;
-
     const llama_vocab & vocab;
 
     std::vector<llm_symbol> symbols;
@@ -3381,39 +3373,26 @@ private:
     llm_bigram_bpe::queue work_queue;
 };
 
-static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab & vocab, const std::string & raw_text, bool bos, bool escape) {
+static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab & vocab, const std::string & raw_text, bool bos) {
     std::vector<llama_vocab::id> output;
 
     if (raw_text.empty()) {
         return output;
     }
 
+    if (bos && vocab.special_bos_id != -1) {
+        output.push_back(vocab.special_bos_id);
+    }
+
     switch (vocab.type) {
         case LLAMA_VOCAB_TYPE_SPM:
             {
                 llm_tokenizer_spm tokenizer(vocab);
-
-                if (bos) {
-                    output.push_back(vocab.special_bos_id);
-                }
-
-                std::string text;
-                if (escape) {
-                    text = llama_escape_whitespace(raw_text);
-                } else {
-                    text = raw_text;
-                }
-
-                tokenizer.tokenize(text, output);
+                tokenizer.tokenize(llama_escape_whitespace(raw_text), output);
             } break;
         case LLAMA_VOCAB_TYPE_BPE:
             {
-                llm_tokenizer_bpe tokenizer(vocab, escape);
-
-                if (bos && vocab.special_bos_id != -1) {
-                    output.push_back(vocab.special_bos_id);
-                }
-
+                llm_tokenizer_bpe tokenizer(vocab);
                 tokenizer.tokenize(raw_text, output);
             } break;
     };
@@ -6095,8 +6074,7 @@ int llama_tokenize_with_model(
                  llama_token * tokens,
                          int   n_max_tokens,
                         bool   add_bos) {
-    auto escape = llama_vocab_get_type(model->vocab) == LLAMA_VOCAB_TYPE_SPM;
-    auto res = llama_tokenize_internal(model->vocab, text, add_bos, escape);
+    auto res = llama_tokenize_internal(model->vocab, text, add_bos);
 
     if (n_max_tokens < (int) res.size()) {
         LLAMA_LOG_ERROR("%s: too many tokens\n", __func__);
index f3ee851a3880ca029ef41458ea868cbc4a86776d..7e9ac9188d5c5b15f2e90c1e321ab51c1ac0e328 100644 (file)
@@ -100,7 +100,8 @@ int main(int argc, char **argv) {
     bool success = true;
 
     for (const auto & test_kv : k_tests()) {
-        std::vector<llama_token> res = llama_tokenize(ctx, test_kv.first, true);
+        // Add a space in front of the first character to match OG llama tokenizer behavior
+        std::vector<llama_token> res = llama_tokenize(ctx, " " + test_kv.first, true);
         fprintf(stderr, "%s : '%s' tokenized to '%s'\n",
             __func__, test_kv.first.c_str(), unescape_whitespace(ctx, res).c_str());