]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
BERT tokenizer fixes (#6498)
authorJared Van Bortel <redacted>
Tue, 9 Apr 2024 17:44:08 +0000 (13:44 -0400)
committerGitHub <redacted>
Tue, 9 Apr 2024 17:44:08 +0000 (13:44 -0400)
Key changes:
* BERT conversion: fix abuse of LlamaHfVocab, do not set BOS or EOS
* Nomic Embed conversion: pad vocab instead of slicing embedding tensor
* llama_tokenize: handle added special tokens like HF does

20 files changed:
common/common.cpp
common/common.h
convert-hf-to-gguf.py
convert-persimmon-to-gguf.py
convert.py
examples/embedding/embedding.cpp
examples/imatrix/imatrix.cpp
examples/infill/infill.cpp
examples/llava/llava-cli.cpp
examples/lookahead/lookahead.cpp
examples/lookup/lookup-create.cpp
examples/lookup/lookup-stats.cpp
examples/lookup/lookup.cpp
examples/main/main.cpp
examples/perplexity/perplexity.cpp
examples/server/server.cpp
examples/speculative/speculative.cpp
examples/tokenize/tokenize.cpp
llama.cpp
llama.h

index 7d983a453c68f826b2d5974366380af185eb729c..98fc8388cf3e0408b207a15eafa272269fef0ece 100644 (file)
@@ -2212,23 +2212,23 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
 std::vector<llama_token> llama_tokenize(
   const struct llama_context * ctx,
            const std::string & text,
-                        bool   add_bos,
-                        bool   special) {
-    return llama_tokenize(llama_get_model(ctx), text, add_bos, special);
+                        bool   add_special,
+                        bool   parse_special) {
+    return llama_tokenize(llama_get_model(ctx), text, add_special, parse_special);
 }
 
 std::vector<llama_token> llama_tokenize(
     const struct llama_model * model,
            const std::string & text,
-                        bool   add_bos,
-                        bool   special) {
+                        bool   add_special,
+                        bool   parse_special) {
     // upper limit for the number of tokens
-    int n_tokens = text.length() + add_bos;
+    int n_tokens = text.length() + 2 * add_special;
     std::vector<llama_token> result(n_tokens);
-    n_tokens = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_bos, special);
+    n_tokens = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_special, parse_special);
     if (n_tokens < 0) {
         result.resize(-n_tokens);
-        int check = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_bos, special);
+        int check = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_special, parse_special);
         GGML_ASSERT(check == -n_tokens);
     } else {
         result.resize(n_tokens);
index 4635e05d6381f01c89113b08266cffc57cdb5ab0..a7f476c1bf0cb9539fe0dd4618f8ea7bf3d265a4 100644 (file)
@@ -223,14 +223,14 @@ void llama_batch_add(
 std::vector<llama_token> llama_tokenize(
   const struct llama_context * ctx,
            const std::string & text,
-                        bool   add_bos,
-                        bool   special = false);
+                        bool   add_special,
+                        bool   parse_special = false);
 
 std::vector<llama_token> llama_tokenize(
     const struct llama_model * model,
            const std::string & text,
-                        bool   add_bos,
-                        bool   special = false);
+                        bool   add_special,
+                        bool   parse_special = false);
 
 // tokenizes a token into a piece
 // should work similar to Python's `tokenizer.id_to_piece`
index 37af6328a1705bae0ab6b5885f21f14f275d786d..63710676bad1c485bea4fbb951f2e1a200f914af 100755 (executable)
@@ -227,15 +227,14 @@ class Model(ABC):
             return ("pytorch_model.bin",)
         return (f"pytorch_model-{n:05}-of-{self.num_parts:05}.bin" for n in range(1, self.num_parts + 1))
 
-    def _set_vocab_gpt2(self):
-        dir_model = self.dir_model
-        hparams = self.hparams
+    # used for GPT-2 BPE and WordPiece vocabs
+    def get_basic_vocab(self) -> tuple[list[str], list[int]]:
         tokens: list[str] = []
         toktypes: list[int] = []
 
         from transformers import AutoTokenizer
-        tokenizer = AutoTokenizer.from_pretrained(dir_model)
-        vocab_size = hparams.get("vocab_size", len(tokenizer.vocab))
+        tokenizer = AutoTokenizer.from_pretrained(self.dir_model)
+        vocab_size = self.hparams.get("vocab_size", len(tokenizer.vocab))
         assert max(tokenizer.vocab.values()) < vocab_size
 
         reverse_vocab = {id_: encoded_tok for encoded_tok, id_ in tokenizer.vocab.items()}
@@ -255,11 +254,15 @@ class Model(ABC):
                 tokens.append(reverse_vocab[i])
                 toktypes.append(gguf.TokenType.NORMAL)
 
+        return tokens, toktypes
+
+    def _set_vocab_gpt2(self) -> None:
+        tokens, toktypes = self.get_basic_vocab()
         self.gguf_writer.add_tokenizer_model("gpt2")
         self.gguf_writer.add_token_list(tokens)
         self.gguf_writer.add_token_types(toktypes)
 
-        special_vocab = gguf.SpecialVocab(dir_model, load_merges=True)
+        special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True)
         special_vocab.add_to_gguf(self.gguf_writer)
 
     def _set_vocab_qwen(self):
@@ -2043,34 +2046,25 @@ class BertModel(Model):
             self.gguf_writer.add_pooling_type(pooling_type)
 
     def set_vocab(self):
-        # use huggingface vocab to get all tokens
-        vocab = LlamaHfVocab(self.dir_model, ignore_nonllama=True)
-        tokens, scores, toktypes = zip(*vocab.all_tokens())
-        assert len(tokens) == vocab.vocab_size
-        self.vocab_size = vocab.vocab_size
+        tokens, toktypes = self.get_basic_vocab()
+        self.vocab_size = len(tokens)
 
         # we need this to validate the size of the token_type embeddings
         # though currently we are passing all zeros to the token_type embeddings
-        n_token_types = len(set(toktypes))
-        self.gguf_writer.add_token_type_count(n_token_types)
+        self.gguf_writer.add_token_type_count(2)  # "Sequence A" or "Sequence B"
 
         # convert to phantom space vocab
-        def phantom(tok, typ):
-            if tok.startswith(b"[") and tok.endswith(b"]"):
+        def phantom(tok):
+            if tok.startswith("[") and tok.endswith("]"):
                 return tok
-            if tok.startswith(b"##"):
+            if tok.startswith("##"):
                 return tok[2:]
-            return b"\xe2\x96\x81" + tok
-        tokens = tuple(phantom(t, y) for t, y in zip(tokens, toktypes))
-
-        # set up bos and eos tokens (cls and sep)
-        self.gguf_writer.add_bos_token_id(vocab.tokenizer.cls_token_id)
-        self.gguf_writer.add_eos_token_id(vocab.tokenizer.sep_token_id)
+            return "\u2581" + tok
+        tokens = list(map(phantom, tokens))
 
         # add vocab to gguf
         self.gguf_writer.add_tokenizer_model("bert")
         self.gguf_writer.add_token_list(tokens)
-        self.gguf_writer.add_token_scores(scores)
         self.gguf_writer.add_token_types(toktypes)
 
         # handle special tokens
@@ -2142,16 +2136,6 @@ class NomicBertModel(BertModel):
         super().set_gguf_parameters()
         self.gguf_writer.add_rope_freq_base(self.hparams["rotary_emb_base"])
 
-    def get_tensors(self):
-        assert self.vocab_size is not None
-        for name, data in super().get_tensors():
-            # Nomic Embed's token embeddings tensor is padded, but llama.cpp wants tensor sizes to match exactly.
-            if name == 'embeddings.word_embeddings.weight' and data.shape[1] != self.vocab_size:
-                rounded_vocab_size = (self.vocab_size + 63) // 64 * 64
-                assert data.shape == (rounded_vocab_size, self.hparams["n_embd"])
-                data = data[:self.vocab_size, :]
-            yield name, data
-
 
 @Model.register("GemmaForCausalLM")
 class GemmaModel(Model):
@@ -2327,7 +2311,8 @@ class MambaModel(Model):
                 data = data.astype(np.float32)
 
             # if f16 desired, convert big float32 2-dim weight tensors to float16
-            if self.ftype == 1 and data_dtype == np.float32 and new_name.removesuffix(".weight").endswith((".ssm_in", ".ssm_out", "token_embd", "output")) and n_dims == 2:
+            new_weight_name = new_name[:-len(".weight")] if new_name.endswith(".weight") else ""
+            if self.ftype == 1 and data_dtype == np.float32 and new_weight_name.endswith((".ssm_in", ".ssm_out", "token_embd", "output")) and n_dims == 2:
                 data = data.astype(np.float16)
 
             print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
index ccb99279e20a8c2bf441dbcfd7e5d292b270afa5..69be17f94efd9d87718f362270edd25ec6617e52 100755 (executable)
@@ -1,4 +1,6 @@
 #!/usr/bin/env python3
+from __future__ import annotations
+
 import argparse
 import os
 import sys
index a37aeb5e5a652a5c36c5eddf805c262c41b895a0..e860ac89fc2776eb1a393b4eb613aa8e4e66f817 100755 (executable)
@@ -33,7 +33,7 @@ if 'NO_LOCAL_GGUF' not in os.environ:
 import gguf
 
 if TYPE_CHECKING:
-    from typing import TypeAlias
+    from typing_extensions import Self, TypeAlias
 
 if hasattr(faulthandler, 'register') and hasattr(signal, 'SIGUSR1'):
     faulthandler.register(signal.SIGUSR1)
@@ -517,7 +517,7 @@ class LlamaHfVocab(Vocab):
     tokenizer_model = "llama"
     name = "hfft"
 
-    def __init__(self, base_path: Path, ignore_nonllama: bool = False):
+    def __init__(self, base_path: Path):
         fname_tokenizer = base_path / FAST_TOKENIZER_FILE
         # if this fails, FileNotFoundError propagates to caller
         with open(fname_tokenizer, encoding='utf-8') as f:
@@ -525,9 +525,7 @@ class LlamaHfVocab(Vocab):
 
         # pre-check so we know if we need transformers
         tokenizer_model: dict[str, Any] = tokenizer_json['model']
-        if ignore_nonllama:
-            pass  # workaround incorrect use of this class for WordPiece
-        elif (
+        if (
             tokenizer_model['type'] != 'BPE' or not tokenizer_model.get('byte_fallback', False)
             or tokenizer_json['decoder']['type'] != 'Sequence'
         ):
@@ -647,16 +645,17 @@ def permute(weights: NDArray, n_head: int, n_head_kv: int) -> NDArray:
 
 
 class Tensor(ABC):
+    ndarray: NDArray
     data_type: DataType
 
     @abstractmethod
-    def astype(self, data_type: DataType) -> Tensor: ...
+    def astype(self, data_type: DataType) -> Self: ...
     @abstractmethod
-    def permute(self, n_head: int, n_head_kv: int) -> Tensor: ...
+    def permute(self, n_head: int, n_head_kv: int) -> Self: ...
     @abstractmethod
-    def permute_part(self, n_part: int, n_head: int, n_head_kv: int) -> UnquantizedTensor: ...
+    def permute_part(self, n_part: int, n_head: int, n_head_kv: int) -> Self: ...
     @abstractmethod
-    def part(self, n_part: int) -> UnquantizedTensor: ...
+    def part(self, n_part: int) -> Self: ...
     @abstractmethod
     def to_ggml(self) -> GGMLCompatibleTensor: ...
 
@@ -673,13 +672,13 @@ class UnquantizedTensor(Tensor):
         self.ndarray = ndarray
         self.data_type = NUMPY_TYPE_TO_DATA_TYPE[ndarray.dtype]
 
-    def astype(self, data_type: DataType) -> Tensor:
+    def astype(self, data_type: DataType) -> UnquantizedTensor:
         dtype = data_type.dtype
         if self.data_type == DT_BF16:
             self.ndarray = bf16_to_fp32(self.ndarray)
         return UnquantizedTensor(self.ndarray.astype(dtype))
 
-    def to_ggml(self) -> UnquantizedTensor:
+    def to_ggml(self) -> Self:
         return self
 
     def permute_part(self, n_part: int, n_head: int, n_head_kv: int) -> UnquantizedTensor:
index 536657526685ca34217b0118988df6f625a5ecdf..6a93147d70e88959f66766bcadd88a08578ead6d 100644 (file)
@@ -123,10 +123,10 @@ int main(int argc, char ** argv) {
         inputs.push_back(inp);
     }
 
-    // add eos if not present
+    // add SEP if not present
     for (auto & inp : inputs) {
-        if (inp.empty() || inp.back() != llama_token_eos(model)) {
-            inp.push_back(llama_token_eos(model));
+        if (inp.empty() || inp.back() != llama_token_sep(model)) {
+            inp.push_back(llama_token_sep(model));
         }
     }
 
index d8cb0a6420456547229acf8aaca0398ecd1b4282..1bf55f90c0f49df43591e6a5f1fec98fd3a23b57 100644 (file)
@@ -349,12 +349,13 @@ static void process_logits(
 static bool compute_imatrix(llama_context * ctx, const gpt_params & params, bool compute_ppl, int from_chunk) {
 
     const bool add_bos = llama_should_add_bos_token(llama_get_model(ctx));
+    GGML_ASSERT(llama_add_eos_token(llama_get_model(ctx)) != 1);
     const int n_ctx = llama_n_ctx(ctx);
 
     auto tim1 = std::chrono::high_resolution_clock::now();
     fprintf(stderr, "%s: tokenizing the input ..\n", __func__);
 
-    std::vector<llama_token> tokens = ::llama_tokenize(ctx, params.prompt, add_bos);
+    std::vector<llama_token> tokens = ::llama_tokenize(ctx, params.prompt, true);
 
     auto tim2 = std::chrono::high_resolution_clock::now();
     fprintf(stderr, "%s: tokenization took %g ms\n",__func__,1e-3*std::chrono::duration_cast<std::chrono::microseconds>(tim2-tim1).count());
index 91c39c5ae42e35dd4ec6675108470a184398a9a3..c69dcd06e461fee9b1b59ffcf093ccfd18babdf4 100644 (file)
@@ -239,6 +239,7 @@ int main(int argc, char ** argv) {
         LOG_TEE("%s\n", get_system_info(params).c_str());
     }
     const bool add_bos = llama_should_add_bos_token(model);
+    GGML_ASSERT(llama_add_eos_token(model) != 1);
     LOG("add_bos: %d\n", add_bos);
 
     bool suff_rm_leading_spc = params.escape;
@@ -279,10 +280,10 @@ int main(int argc, char ** argv) {
     if (ctx_guidance) {
         LOG("cfg_negative_prompt: \"%s\"\n", log_tostr(sparams.cfg_negative_prompt));
 
-        guidance_inp = ::llama_tokenize(ctx_guidance, sparams.cfg_negative_prompt, add_bos);
+        guidance_inp = ::llama_tokenize(ctx_guidance, sparams.cfg_negative_prompt, true);
         LOG("guidance_inp tokenized: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_guidance, guidance_inp).c_str());
 
-        std::vector<llama_token> original_inp = ::llama_tokenize(ctx, params.prompt, add_bos);
+        std::vector<llama_token> original_inp = ::llama_tokenize(ctx, params.prompt, true);
         LOG("original_inp tokenized: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, original_inp).c_str());
 
         original_prompt_len = original_inp.size();
index e29da6cb2f9b10edb048fc7a58c1cb8e90ff0de4..75948806ee5d4c878118f399a759b07b9f20d00c 100644 (file)
@@ -146,7 +146,6 @@ static void process_prompt(struct llava_context * ctx_llava, struct llava_image_
     int n_past = 0;
 
     const int max_tgt_len = params->n_predict < 0 ? 256 : params->n_predict;
-    const bool add_bos = llama_should_add_bos_token(llama_get_model(ctx_llava->ctx_llama));
 
     std::string system_prompt, user_prompt;
     size_t image_pos = prompt.find("<image>");
@@ -180,7 +179,7 @@ static void process_prompt(struct llava_context * ctx_llava, struct llava_image_
         }
     }
 
-    eval_string(ctx_llava->ctx_llama, system_prompt.c_str(), params->n_batch, &n_past, add_bos);
+    eval_string(ctx_llava->ctx_llama, system_prompt.c_str(), params->n_batch, &n_past, true);
     llava_eval_image_embed(ctx_llava->ctx_llama, image_embed, params->n_batch, &n_past);
     eval_string(ctx_llava->ctx_llama, user_prompt.c_str(), params->n_batch, &n_past, false);
 
index e2551e7a494c21f5a09d31c03f1169bef315f9cd..5af6a8ab6c92bc9422350dd3b443b206c72594ee 100644 (file)
@@ -64,13 +64,10 @@ int main(int argc, char ** argv) {
     std::tie(model, ctx) = llama_init_from_gpt_params(params);
 
     // Tokenize the prompt
-    const bool add_bos = llama_should_add_bos_token(model);
-    LOG("add_bos tgt: %d\n", add_bos);
-
     std::vector<llama_token> inp;
     std::vector<llama_token> all;
 
-    inp = ::llama_tokenize(ctx, params.prompt, add_bos, true);
+    inp = ::llama_tokenize(ctx, params.prompt, true, true);
     all = inp;
 
     const int max_context_size     = llama_n_ctx(ctx);
index 46a6bed078e7129888d41cc960cb202e931cb528..1c230c9667c715dd7c64c207c48b91b4bd9508a4 100644 (file)
@@ -28,10 +28,8 @@ int main(int argc, char ** argv){
     GGML_ASSERT(model != nullptr);
 
     // tokenize the prompt
-    const bool add_bos = llama_should_add_bos_token(model);
-
     std::vector<llama_token> inp;
-    inp = ::llama_tokenize(ctx, params.prompt, add_bos, true);
+    inp = ::llama_tokenize(ctx, params.prompt, true, true);
     fprintf(stderr, "%s: tokenization done\n", __func__);
 
 
index 31f227773de6f0f1e50fc923a79d736b5ca811cd..41b62c2fe9f76b4966717ae5fe665aa9571cbc83 100644 (file)
@@ -34,11 +34,8 @@ int main(int argc, char ** argv){
     GGML_ASSERT(llama_n_vocab(model) < (1 << 16));
 
     // tokenize the prompt
-    const bool add_bos = llama_should_add_bos_token(model);
-    LOG("add_bos tgt: %d\n", add_bos);
-
     std::vector<llama_token> inp;
-    inp = ::llama_tokenize(ctx, params.prompt, add_bos, true);
+    inp = ::llama_tokenize(ctx, params.prompt, true, true);
 
     llama_ngram_cache ngram_cache_context;
     llama_ngram_cache ngram_cache_dynamic;
index 2e8c35de31427d5ebaab1f446dc604e55f4660c2..65ed408a2758394f03508113f97319cd01d84ed0 100644 (file)
@@ -42,11 +42,8 @@ int main(int argc, char ** argv){
     GGML_ASSERT(llama_n_vocab(model) < (1 << 16));
 
     // tokenize the prompt
-    const bool add_bos = llama_should_add_bos_token(model);
-    LOG("add_bos tgt: %d\n", add_bos);
-
     std::vector<llama_token> inp;
-    inp = ::llama_tokenize(ctx, params.prompt, add_bos, true);
+    inp = ::llama_tokenize(ctx, params.prompt, true, true);
 
     llama_ngram_cache ngram_cache_context;
     llama_ngram_cache ngram_cache_dynamic;
index 711f162d79fca311175c78d50dbd39c8b6e15166..249fc2bb605b36fc5f9184a7ce4475d4f8740ca7 100644 (file)
@@ -246,6 +246,7 @@ int main(int argc, char ** argv) {
     }
 
     const bool add_bos = llama_should_add_bos_token(model);
+    GGML_ASSERT(llama_add_eos_token(model) != 1);
     LOG("add_bos: %d\n", add_bos);
 
     std::vector<llama_token> embd_inp;
@@ -255,7 +256,7 @@ int main(int argc, char ** argv) {
         if (params.chatml) {
             params.prompt = "<|im_start|>system\n" + params.prompt + "<|im_end|>";
         }
-        embd_inp = ::llama_tokenize(ctx, params.prompt, add_bos, true);
+        embd_inp = ::llama_tokenize(ctx, params.prompt, true, true);
     } else {
         LOG("use session tokens\n");
         embd_inp = session_tokens;
@@ -277,10 +278,10 @@ int main(int argc, char ** argv) {
     if (ctx_guidance) {
         LOG("cfg_negative_prompt: \"%s\"\n", log_tostr(sparams.cfg_negative_prompt));
 
-        guidance_inp = ::llama_tokenize(ctx_guidance, sparams.cfg_negative_prompt, add_bos, true);
+        guidance_inp = ::llama_tokenize(ctx_guidance, sparams.cfg_negative_prompt, true, true);
         LOG("guidance_inp tokenized: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_guidance, guidance_inp).c_str());
 
-        std::vector<llama_token> original_inp = ::llama_tokenize(ctx, params.prompt, add_bos, true);
+        std::vector<llama_token> original_inp = ::llama_tokenize(ctx, params.prompt, true, true);
         LOG("original_inp tokenized: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, original_inp).c_str());
 
         original_prompt_len = original_inp.size();
@@ -339,14 +340,14 @@ int main(int argc, char ** argv) {
     }
 
     // prefix & suffix for instruct mode
-    const auto inp_pfx = ::llama_tokenize(ctx, "\n\n### Instruction:\n\n", add_bos, true);
-    const auto inp_sfx = ::llama_tokenize(ctx, "\n\n### Response:\n\n",    false,   true);
+    const auto inp_pfx = ::llama_tokenize(ctx, "\n\n### Instruction:\n\n", true,  true);
+    const auto inp_sfx = ::llama_tokenize(ctx, "\n\n### Response:\n\n",    false, true);
 
     LOG("inp_pfx: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, inp_pfx).c_str());
     LOG("inp_sfx: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, inp_sfx).c_str());
 
     // chatml prefix & suffix
-    const auto cml_pfx = ::llama_tokenize(ctx, "\n<|im_start|>user\n", add_bos, true);
+    const auto cml_pfx = ::llama_tokenize(ctx, "\n<|im_start|>user\n", true, true);
     const auto cml_sfx = ::llama_tokenize(ctx, "<|im_end|>\n<|im_start|>assistant\n", false, true);
 
     LOG("cml_pfx: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, cml_pfx).c_str());
index c70385c62bb07154f0691e052158d8818f0f63fb..bab79aaea89cabc470bb586eff403496e57369d9 100644 (file)
@@ -315,10 +315,11 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params &
     // BOS tokens will be added for each chunk before eval
 
     const bool add_bos = llama_should_add_bos_token(llama_get_model(ctx));
+    GGML_ASSERT(llama_add_eos_token(llama_get_model(ctx)) != 1);
 
     fprintf(stderr, "%s: tokenizing the input ..\n", __func__);
 
-    std::vector<llama_token> tokens = ::llama_tokenize(ctx, params.prompt, add_bos);
+    std::vector<llama_token> tokens = ::llama_tokenize(ctx, params.prompt, true);
 
     const int n_ctx = llama_n_ctx(ctx);
 
@@ -454,6 +455,7 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
     // BOS tokens will be added for each chunk before eval
 
     const bool add_bos = llama_should_add_bos_token(llama_get_model(ctx));
+    GGML_ASSERT(llama_add_eos_token(llama_get_model(ctx)) != 1);
 
     std::ofstream logits_stream;
     if (!params.logits_file.empty()) {
@@ -470,7 +472,7 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
     auto tim1 = std::chrono::high_resolution_clock::now();
     fprintf(stderr, "%s: tokenizing the input ..\n", __func__);
 
-    std::vector<llama_token> tokens = ::llama_tokenize(ctx, params.prompt, add_bos);
+    std::vector<llama_token> tokens = ::llama_tokenize(ctx, params.prompt, true);
 
     auto tim2 = std::chrono::high_resolution_clock::now();
     fprintf(stderr, "%s: tokenization took %g ms\n",__func__,1e-3*std::chrono::duration_cast<std::chrono::microseconds>(tim2-tim1).count());
@@ -771,9 +773,6 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
     const bool is_spm = llama_vocab_type(llama_get_model(ctx)) == LLAMA_VOCAB_TYPE_SPM;
     fprintf(stderr, "================================= is_spm = %d\n", is_spm);
 
-    // This is needed as usual for LLaMA models
-    const bool add_bos = llama_should_add_bos_token(llama_get_model(ctx));
-
     // The tasks should be randomized so the score stabilizes quickly.
     bool randomize_tasks = true;
 
@@ -818,7 +817,7 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
         hs_cur.gold_ending_idx = std::stoi( prompt_lines[idx*6+1] );
         for (size_t j = 0; j < 4; j++) {
             hs_cur.ending[j] = prompt_lines[idx*6+2+j];
-            hs_cur.seq_tokens[j] = ::llama_tokenize(ctx, hs_cur.context + " " + hs_cur.ending[j], add_bos);
+            hs_cur.seq_tokens[j] = ::llama_tokenize(ctx, hs_cur.context + " " + hs_cur.ending[j], true);
         }
 
         // determine the common prefix of the endings
@@ -837,7 +836,7 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
             hs_cur.seq_tokens[2].size() - hs_cur.common_prefix +
             hs_cur.seq_tokens[3].size() - hs_cur.common_prefix;
 
-        //GGML_ASSERT(hs_cur.common_prefix >= ::llama_tokenize(ctx, hs_cur.context, add_bos).size());
+        //GGML_ASSERT(hs_cur.common_prefix >= ::llama_tokenize(ctx, hs_cur.context, true).size());
 
         // Delete the selected random example from the prompt
         if (randomize_tasks) {
@@ -1110,12 +1109,9 @@ static void winogrande_score(llama_context * ctx, const gpt_params & params) {
 
     fprintf(stderr, "%s : tokenizing selected tasks\n", __func__);
 
-    // This is needed as usual for LLaMA models
-    const bool add_bos = llama_should_add_bos_token(llama_get_model(ctx));
-
     for (auto & task : data) {
-        task.seq_tokens[0] = ::llama_tokenize(ctx, task.first + task.choices[0] + task.second, add_bos);
-        task.seq_tokens[1] = ::llama_tokenize(ctx, task.first + task.choices[1] + task.second, add_bos);
+        task.seq_tokens[0] = ::llama_tokenize(ctx, task.first + task.choices[0] + task.second, true);
+        task.seq_tokens[1] = ::llama_tokenize(ctx, task.first + task.choices[1] + task.second, true);
 
         task.common_prefix = 0;
         for (size_t k = 0; k < task.seq_tokens[0].size(); k++) {
@@ -1130,8 +1126,8 @@ static void winogrande_score(llama_context * ctx, const gpt_params & params) {
             task.seq_tokens[0].size() - task.common_prefix +
             task.seq_tokens[1].size() - task.common_prefix;
 
-        task.n_base1 = ::llama_tokenize(ctx, task.first + task.choices[0], add_bos).size();
-        task.n_base2 = ::llama_tokenize(ctx, task.first + task.choices[1], add_bos).size();
+        task.n_base1 = ::llama_tokenize(ctx, task.first + task.choices[0], true).size();
+        task.n_base2 = ::llama_tokenize(ctx, task.first + task.choices[1], true).size();
     }
 
     fprintf(stderr, "%s : calculating winogrande score over selected tasks.\n", __func__);
@@ -1322,7 +1318,7 @@ struct multiple_choice_task {
     std::vector<float> log_probs;
 };
 
-static bool multiple_choice_prepare_one_task(llama_context * ctx, bool add_bos, multiple_choice_task& task, bool log_error) {
+static bool multiple_choice_prepare_one_task(llama_context * ctx, multiple_choice_task& task, bool log_error) {
     if (task.question.empty() || task.mc1.answers.empty()) {
         if (log_error) {
             printf("%s: found bad task with empty question and/or answers\n", __func__);
@@ -1337,7 +1333,7 @@ static bool multiple_choice_prepare_one_task(llama_context * ctx, bool add_bos,
             }
             return false;
         }
-        task.seq_tokens.emplace_back(::llama_tokenize(ctx, task.question + " " + answer, add_bos));
+        task.seq_tokens.emplace_back(::llama_tokenize(ctx, task.question + " " + answer, true));
     }
     auto min_len = task.seq_tokens.front().size();
     for (auto& seq : task.seq_tokens) {
@@ -1436,9 +1432,6 @@ static void multiple_choice_score(llama_context * ctx, const gpt_params & params
         n_task = params.multiple_choice_tasks;
     }
 
-    // This is needed as usual for LLaMA models
-    const bool add_bos = llama_should_add_bos_token(llama_get_model(ctx));
-
     printf("%s: preparing task data", __func__);
     fflush(stdout);
     if (n_task > 500) {
@@ -1446,7 +1439,7 @@ static void multiple_choice_score(llama_context * ctx, const gpt_params & params
         fflush(stdout);
         std::atomic<int> counter(0);
         std::atomic<int> n_bad(0);
-        auto prepare = [&counter, &n_bad, &tasks, ctx, add_bos] () {
+        auto prepare = [&counter, &n_bad, &tasks, ctx] () {
             int num_tasks = tasks.size();
             int n_bad_local = 0;
             while (true) {
@@ -1457,7 +1450,7 @@ static void multiple_choice_score(llama_context * ctx, const gpt_params & params
                 }
                 int last = std::min(first + K_TOKEN_CHUNK, num_tasks);
                 for (int i = first; i < last; ++i) {
-                    if (!multiple_choice_prepare_one_task(ctx, add_bos, tasks[i], false)) ++n_bad_local;
+                    if (!multiple_choice_prepare_one_task(ctx, tasks[i], false)) ++n_bad_local;
                 }
             }
         };
@@ -1479,7 +1472,7 @@ static void multiple_choice_score(llama_context * ctx, const gpt_params & params
         int i_task = 0;
         for (auto& task : tasks) {
             ++i_task;
-            if (!multiple_choice_prepare_one_task(ctx, add_bos, task, true)) {
+            if (!multiple_choice_prepare_one_task(ctx, task, true)) {
                 return;
             }
             if (i_task%n_dot == 0) {
@@ -1715,6 +1708,7 @@ static void kl_divergence(llama_context * ctx, const gpt_params & params) {
     const int num_batches = (n_ctx + n_batch - 1)/n_batch;
     const int nv = 2*((n_vocab + 1)/2) + 4;
     const bool add_bos = llama_should_add_bos_token(llama_get_model(ctx));
+    GGML_ASSERT(llama_add_eos_token(llama_get_model(ctx)) != 1);
 
     std::vector<uint16_t> log_probs_uint16(size_t(n_ctx - 1 - n_ctx/2) * nv);
     std::vector<float> kld_values(size_t(n_ctx - 1 - n_ctx/2)*n_chunk);
index 6c64fe3e17dec9111f7318a4fdad1ac6b9ab7be1..2e791190b740a05dd316fb1b74c6ca24a78263b6 100644 (file)
@@ -689,6 +689,7 @@ struct server_context {
         n_ctx = llama_n_ctx(ctx);
 
         add_bos_token = llama_should_add_bos_token(model);
+        GGML_ASSERT(llama_add_eos_token(model) != 1);
 
         return true;
     }
@@ -758,7 +759,7 @@ struct server_context {
         metrics.init();
     }
 
-    std::vector<llama_token> tokenize(const json & json_prompt, bool add_bos) const {
+    std::vector<llama_token> tokenize(const json & json_prompt, bool add_special) const {
         // TODO: currently, we tokenize using special tokens by default
         //       this is not always correct (see https://github.com/ggerganov/llama.cpp/pull/4160#issuecomment-1824826216)
         //       but it's better compared to completely ignoring ChatML and other chat templates
@@ -776,7 +777,7 @@ struct server_context {
 
                     std::vector<llama_token> p;
                     if (first) {
-                        p = ::llama_tokenize(ctx, s, add_bos, TMP_FORCE_SPECIAL);
+                        p = ::llama_tokenize(ctx, s, add_special, TMP_FORCE_SPECIAL);
                         first = false;
                     } else {
                         p = ::llama_tokenize(ctx, s, false, TMP_FORCE_SPECIAL);
@@ -793,7 +794,7 @@ struct server_context {
             }
         } else {
             auto s = json_prompt.template get<std::string>();
-            prompt_tokens = ::llama_tokenize(ctx, s, add_bos, TMP_FORCE_SPECIAL);
+            prompt_tokens = ::llama_tokenize(ctx, s, add_special, TMP_FORCE_SPECIAL);
         }
 
         return prompt_tokens;
@@ -1058,7 +1059,7 @@ struct server_context {
         system_tokens.clear();
 
         if (!system_prompt.empty()) {
-            system_tokens = ::llama_tokenize(ctx, system_prompt, add_bos_token);
+            system_tokens = ::llama_tokenize(ctx, system_prompt, true);
 
             llama_batch_clear(batch);
 
@@ -1914,7 +1915,7 @@ struct server_context {
                             prefix_tokens.push_back(llama_token_middle(model));
                             prompt_tokens = prefix_tokens;
                         } else {
-                            prompt_tokens = tokenize(slot.prompt, system_prompt.empty() && add_bos_token);  // add BOS if there isn't system prompt
+                            prompt_tokens = tokenize(slot.prompt, system_prompt.empty()); // add BOS if there isn't system prompt
                         }
 
                         slot.n_past = 0;
index 6e0815b3699868e82c12ed16fd64a0741bd23d72..6a7367b0cde6b0b699fd9a807acf12ff09c84ecc 100644 (file)
@@ -76,6 +76,28 @@ int main(int argc, char ** argv) {
     params.n_threads_batch = params.n_threads_batch_draft;
     std::tie(model_dft, ctx_dft) = llama_init_from_gpt_params(params);
 
+    const bool vocab_type_tgt = llama_vocab_type(model_tgt);
+    LOG("vocab_type tgt: %d\n", vocab_type_tgt);
+
+    const bool vocab_type_dft = llama_vocab_type(model_dft);
+    LOG("vocab_type dft: %d\n", vocab_type_dft);
+
+    if (vocab_type_tgt != vocab_type_dft) {
+        fprintf(stderr, "%s: error: draft model vocab type must match target model to use speculation but ", __func__);
+        fprintf(stderr, "vocab_type_dft = %d while vocab_type_tgt = %d\n", vocab_type_dft, vocab_type_tgt);
+        return 1;
+    }
+
+    if (
+        llama_add_bos_token(model_tgt) != llama_add_bos_token(model_dft) ||
+        llama_add_eos_token(model_tgt) != llama_add_eos_token(model_dft) ||
+        llama_token_bos(model_tgt) != llama_token_bos(model_dft) ||
+        llama_token_eos(model_tgt) != llama_token_eos(model_dft)
+    ) {
+        fprintf(stderr, "%s: error: draft model special tokens must match target model to use speculation\n", __func__);
+        return 1;
+    }
+
     {
         const int n_vocab_tgt = llama_n_vocab(model_tgt);
         const int n_vocab_dft = llama_n_vocab(model_dft);
@@ -105,20 +127,8 @@ int main(int argc, char ** argv) {
 
 
     // Tokenize the prompt
-    const bool add_bos_tgt = llama_should_add_bos_token(model_tgt);
-    LOG("add_bos tgt: %d\n", add_bos_tgt);
-
-    const bool add_bos_dft = llama_should_add_bos_token(model_dft);
-    LOG("add_bos dft: %d\n", add_bos_dft);
-
-    if (add_bos_tgt != add_bos_dft) {
-        fprintf(stderr, "%s: error: draft model add_bos must match target model to use speculation but ", __func__);
-        fprintf(stderr, "add_bos_dft = %d while add_bos_tgt = %d\n", add_bos_dft, add_bos_tgt);
-        return 1;
-    }
-
     std::vector<llama_token> inp;
-    inp = ::llama_tokenize(ctx_tgt, params.prompt, add_bos_tgt, true);
+    inp = ::llama_tokenize(ctx_tgt, params.prompt, true, true);
 
     const int max_context_size     = llama_n_ctx(ctx_tgt);
     const int max_tokens_list_size = max_context_size - 4;
index d95a9247525ebd90a240970f224497d63e951c76..8b1baea800cc86a1d5a5d901582c736f6c0c811a 100644 (file)
@@ -26,11 +26,9 @@ int main(int argc, char ** argv) {
     llama_context_params ctx_params = llama_context_default_params();
     llama_context * ctx = llama_new_context_with_model(model, ctx_params);
 
-    const bool add_bos = llama_should_add_bos_token(model);
-
     std::vector<llama_token> tokens;
 
-    tokens = ::llama_tokenize(model, prompt, add_bos, true);
+    tokens = ::llama_tokenize(model, prompt, true, true);
 
     for (int i = 0; i < (int) tokens.size(); i++) {
         if (printing_ids) {
index 6a090d1bbc24c975b506493aed05e3e82b524b08..8dbf474865038a61c93ff8b94e6025cdd71d4b6c 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
@@ -318,6 +318,8 @@ enum llm_kv {
     LLM_KV_TOKENIZER_UNK_ID,
     LLM_KV_TOKENIZER_SEP_ID,
     LLM_KV_TOKENIZER_PAD_ID,
+    LLM_KV_TOKENIZER_CLS_ID,
+    LLM_KV_TOKENIZER_MASK_ID,
     LLM_KV_TOKENIZER_ADD_BOS,
     LLM_KV_TOKENIZER_ADD_EOS,
     LLM_KV_TOKENIZER_ADD_PREFIX,
@@ -388,6 +390,8 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_TOKENIZER_UNK_ID,              "tokenizer.ggml.unknown_token_id"   },
     { LLM_KV_TOKENIZER_SEP_ID,              "tokenizer.ggml.seperator_token_id" },
     { LLM_KV_TOKENIZER_PAD_ID,              "tokenizer.ggml.padding_token_id"   },
+    { LLM_KV_TOKENIZER_CLS_ID,              "tokenizer.ggml.cls_token_id"       },
+    { LLM_KV_TOKENIZER_MASK_ID,             "tokenizer.ggml.mask_token_id"      },
     { LLM_KV_TOKENIZER_ADD_BOS,             "tokenizer.ggml.add_bos_token"      },
     { LLM_KV_TOKENIZER_ADD_EOS,             "tokenizer.ggml.add_eos_token"      },
     { LLM_KV_TOKENIZER_ADD_PREFIX,          "tokenizer.ggml.add_space_prefix"   },
@@ -2018,11 +2022,13 @@ struct llama_vocab {
     std::map<std::pair<std::string, std::string>, int> bpe_ranks;
 
     // default LLaMA special tokens
-    id special_bos_id = 1;
-    id special_eos_id = 2;
-    id special_unk_id = 0;
-    id special_sep_id = -1;
-    id special_pad_id = -1;
+    id special_bos_id  = 1;
+    id special_eos_id  = 2;
+    id special_unk_id  = 0;
+    id special_sep_id  = -1;
+    id special_pad_id  = -1;
+    id special_cls_id  = -1;
+    id special_mask_id = -1;
 
     int special_add_bos = -1; // -1 unknown, 1 add, 0 don't add.
     int special_add_eos = -1; // -1 unknown, 1 add, 0 don't add.
@@ -3978,7 +3984,9 @@ static void llm_load_hparams(
 }
 
 // TODO: This should probably be in llama.h
-static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab & vocab, std::string raw_text, bool bos, bool special = false);
+static std::vector<llama_vocab::id> llama_tokenize_internal(
+    const llama_vocab & vocab, std::string raw_text, bool add_special, bool parse_special = false
+);
 static llama_token llama_byte_to_token(const llama_vocab & vocab, uint8_t ch);
 
 static void llm_load_vocab(
@@ -4000,23 +4008,27 @@ static void llm_load_vocab(
             vocab.type = LLAMA_VOCAB_TYPE_NONE;
 
             // default special tokens
-            vocab.special_bos_id = -1;
-            vocab.special_eos_id = -1;
-            vocab.special_unk_id = -1;
-            vocab.special_sep_id = -1;
-            vocab.special_pad_id = -1;
-            vocab.linefeed_id    = -1;
+            vocab.special_bos_id  = -1;
+            vocab.special_eos_id  = -1;
+            vocab.special_unk_id  = -1;
+            vocab.special_sep_id  = -1;
+            vocab.special_pad_id  = -1;
+            vocab.special_cls_id  = -1;
+            vocab.special_mask_id = -1;
+            vocab.linefeed_id     = -1;
 
             return;
         } else if (tokenizer_name == "llama") {
             vocab.type = LLAMA_VOCAB_TYPE_SPM;
 
             // default special tokens
-            vocab.special_bos_id = 1;
-            vocab.special_eos_id = 2;
-            vocab.special_unk_id = 0;
-            vocab.special_sep_id = -1;
-            vocab.special_pad_id = -1;
+            vocab.special_bos_id  = 1;
+            vocab.special_eos_id  = 2;
+            vocab.special_unk_id  = 0;
+            vocab.special_sep_id  = -1;
+            vocab.special_pad_id  = -1;
+            vocab.special_cls_id  = -1;
+            vocab.special_mask_id = -1;
 
             const int add_space_prefix_keyidx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_ADD_PREFIX).c_str());
             if (add_space_prefix_keyidx != -1) {
@@ -4051,20 +4063,24 @@ static void llm_load_vocab(
             }
 
             // default special tokens
-            vocab.special_bos_id = 11;
-            vocab.special_eos_id = 11;
-            vocab.special_unk_id = -1;
-            vocab.special_sep_id = -1;
-            vocab.special_pad_id = -1;
+            vocab.special_bos_id  = 11;
+            vocab.special_eos_id  = 11;
+            vocab.special_unk_id  = -1;
+            vocab.special_sep_id  = -1;
+            vocab.special_pad_id  = -1;
+            vocab.special_cls_id  = -1;
+            vocab.special_mask_id = -1;
         } else if (tokenizer_name == "bert") {
             vocab.type = LLAMA_VOCAB_TYPE_WPM;
 
             // default special tokens
-            vocab.special_bos_id = 101;
-            vocab.special_eos_id = 102;
-            vocab.special_unk_id = 100;
-            vocab.special_sep_id = -1;
-            vocab.special_pad_id = -1;
+            vocab.special_bos_id  = -1;
+            vocab.special_eos_id  = -1;
+            vocab.special_unk_id  = 100;
+            vocab.special_sep_id  = 102;
+            vocab.special_pad_id  = 0;
+            vocab.special_cls_id  = 101;
+            vocab.special_mask_id = 103;
             vocab.add_space_prefix = false;
         } else {
             LLAMA_LOG_WARN("%s: unknown tokenizer: '%s'", __func__, tokenizer_name.c_str());
@@ -4127,11 +4143,13 @@ static void llm_load_vocab(
     // special tokens
     {
         const std::vector<std::pair<enum llm_kv, int32_t &>> special_token_types = {
-            { LLM_KV_TOKENIZER_BOS_ID, vocab.special_bos_id },
-            { LLM_KV_TOKENIZER_EOS_ID, vocab.special_eos_id },
-            { LLM_KV_TOKENIZER_UNK_ID, vocab.special_unk_id },
-            { LLM_KV_TOKENIZER_SEP_ID, vocab.special_sep_id },
-            { LLM_KV_TOKENIZER_PAD_ID, vocab.special_pad_id },
+            { LLM_KV_TOKENIZER_BOS_ID,  vocab.special_bos_id  },
+            { LLM_KV_TOKENIZER_EOS_ID,  vocab.special_eos_id  },
+            { LLM_KV_TOKENIZER_UNK_ID,  vocab.special_unk_id  },
+            { LLM_KV_TOKENIZER_SEP_ID,  vocab.special_sep_id  },
+            { LLM_KV_TOKENIZER_PAD_ID,  vocab.special_pad_id  },
+            { LLM_KV_TOKENIZER_CLS_ID,  vocab.special_cls_id  },
+            { LLM_KV_TOKENIZER_MASK_ID, vocab.special_mask_id },
         };
         for (const auto & it : special_token_types) {
             const std::string & key = kv(std::get<0>(it));
@@ -4323,12 +4341,14 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
     LLAMA_LOG_INFO("%s: general.name     = %s\n",    __func__, model.name.c_str());
 
     // special tokens
-    if (vocab.special_bos_id != -1) { LLAMA_LOG_INFO( "%s: BOS token        = %d '%s'\n", __func__, vocab.special_bos_id, vocab.id_to_token[vocab.special_bos_id].text.c_str() ); }
-    if (vocab.special_eos_id != -1) { LLAMA_LOG_INFO( "%s: EOS token        = %d '%s'\n", __func__, vocab.special_eos_id, vocab.id_to_token[vocab.special_eos_id].text.c_str() ); }
-    if (vocab.special_unk_id != -1) { LLAMA_LOG_INFO( "%s: UNK token        = %d '%s'\n", __func__, vocab.special_unk_id, vocab.id_to_token[vocab.special_unk_id].text.c_str() ); }
-    if (vocab.special_sep_id != -1) { LLAMA_LOG_INFO( "%s: SEP token        = %d '%s'\n", __func__, vocab.special_sep_id, vocab.id_to_token[vocab.special_sep_id].text.c_str() ); }
-    if (vocab.special_pad_id != -1) { LLAMA_LOG_INFO( "%s: PAD token        = %d '%s'\n", __func__, vocab.special_pad_id, vocab.id_to_token[vocab.special_pad_id].text.c_str() ); }
-    if (vocab.linefeed_id    != -1) { LLAMA_LOG_INFO( "%s: LF token         = %d '%s'\n", __func__, vocab.linefeed_id,    vocab.id_to_token[vocab.linefeed_id].text.c_str() );    }
+    if (vocab.special_bos_id  != -1) { LLAMA_LOG_INFO( "%s: BOS token        = %d '%s'\n", __func__, vocab.special_bos_id,  vocab.id_to_token[vocab.special_bos_id].text.c_str() );  }
+    if (vocab.special_eos_id  != -1) { LLAMA_LOG_INFO( "%s: EOS token        = %d '%s'\n", __func__, vocab.special_eos_id,  vocab.id_to_token[vocab.special_eos_id].text.c_str() );  }
+    if (vocab.special_unk_id  != -1) { LLAMA_LOG_INFO( "%s: UNK token        = %d '%s'\n", __func__, vocab.special_unk_id,  vocab.id_to_token[vocab.special_unk_id].text.c_str() );  }
+    if (vocab.special_sep_id  != -1) { LLAMA_LOG_INFO( "%s: SEP token        = %d '%s'\n", __func__, vocab.special_sep_id,  vocab.id_to_token[vocab.special_sep_id].text.c_str() );  }
+    if (vocab.special_pad_id  != -1) { LLAMA_LOG_INFO( "%s: PAD token        = %d '%s'\n", __func__, vocab.special_pad_id,  vocab.id_to_token[vocab.special_pad_id].text.c_str() );  }
+    if (vocab.special_cls_id  != -1) { LLAMA_LOG_INFO( "%s: CLS token        = %d '%s'\n", __func__, vocab.special_cls_id,  vocab.id_to_token[vocab.special_cls_id].text.c_str() );  }
+    if (vocab.special_mask_id != -1) { LLAMA_LOG_INFO( "%s: MASK token       = %d '%s'\n", __func__, vocab.special_mask_id, vocab.id_to_token[vocab.special_mask_id].text.c_str() ); }
+    if (vocab.linefeed_id     != -1) { LLAMA_LOG_INFO( "%s: LF token         = %d '%s'\n", __func__, vocab.linefeed_id,     vocab.id_to_token[vocab.linefeed_id].text.c_str() );     }
 }
 
 // Returns false if cancelled by progress_callback
@@ -11358,9 +11378,6 @@ struct llm_tokenizer_wpm {
                 output.push_back(vocab.special_unk_id);
             }
         }
-
-        // append eos token
-        output.push_back(vocab.special_eos_id);
     }
 
     std::vector<std::string> preprocess(const std::string & text) {
@@ -11565,30 +11582,28 @@ static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list<
     }
 }
 
-static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab & vocab, std::string raw_text, bool bos, bool special) {
+static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab & vocab, std::string raw_text, bool add_special, bool parse_special) {
     std::vector<llama_vocab::id> output;
-
-    // OG tokenizer behavior:
-    //
-    // tokenizer.encode('', add_bos=True)  returns [1]
-    // tokenizer.encode('', add_bos=False) returns []
-
-    if (bos && vocab.special_bos_id != -1) {
-        output.push_back(vocab.special_bos_id);
-    }
-
-    if (raw_text.empty()) {
-        return output;
-    }
-
     std::forward_list<fragment_buffer_variant> fragment_buffer;
-    fragment_buffer.emplace_front(raw_text, 0, raw_text.length());
 
-    if (special) tokenizer_st_partition(vocab, fragment_buffer);
+    if (!raw_text.empty()) {
+        fragment_buffer.emplace_front(raw_text, 0, raw_text.length());
+        if (parse_special) tokenizer_st_partition(vocab, fragment_buffer);
+    }
 
     switch (vocab.type) {
         case LLAMA_VOCAB_TYPE_SPM:
             {
+                // OG tokenizer behavior:
+                //
+                // tokenizer.encode('', add_special_tokens=True)  returns [1]
+                // tokenizer.encode('', add_special_tokens=False) returns []
+
+                if (add_special && vocab.special_add_bos != 0) {
+                    GGML_ASSERT(vocab.special_bos_id != -1);
+                    output.push_back(vocab.special_bos_id);
+                }
+
                 for (const auto & fragment : fragment_buffer) {
                     if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
                         // without adding this leading whitespace, we do not get the same results as the original tokenizer
@@ -11614,9 +11629,19 @@ static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab &
                         output.push_back(fragment.token);
                     }
                 }
+
+                if (add_special && vocab.special_add_eos == 1) {
+                    GGML_ASSERT(vocab.special_eos_id != -1);
+                    output.push_back(vocab.special_eos_id);
+                }
             } break;
         case LLAMA_VOCAB_TYPE_BPE:
             {
+                if (add_special && vocab.special_add_bos == 1) {
+                    GGML_ASSERT(vocab.special_bos_id != -1);
+                    output.push_back(vocab.special_bos_id);
+                }
+
                 for (const auto & fragment : fragment_buffer) {
                     if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
                         auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length);
@@ -11630,9 +11655,16 @@ static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab &
                         output.push_back(fragment.token);
                     }
                 }
+
+                GGML_ASSERT(vocab.special_add_eos != 1);
             } break;
         case LLAMA_VOCAB_TYPE_WPM:
             {
+                if (add_special) {
+                    GGML_ASSERT(vocab.special_cls_id != -1);
+                    output.push_back(vocab.special_cls_id);
+                }
+
                 for (const auto & fragment : fragment_buffer) {
                     if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
                         auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length);
@@ -11646,6 +11678,11 @@ static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab &
                         output.push_back(fragment.token);
                     }
                 }
+
+                if (add_special) {
+                    GGML_ASSERT(vocab.special_sep_id != -1);
+                    output.push_back(vocab.special_sep_id);
+                }
             } break;
         case LLAMA_VOCAB_TYPE_NONE:
             GGML_ASSERT(false);
@@ -16104,6 +16141,14 @@ llama_token llama_token_eos(const struct llama_model * model) {
     return model->vocab.special_eos_id;
 }
 
+llama_token llama_token_cls(const struct llama_model * model) {
+    return model->vocab.special_cls_id;
+}
+
+llama_token llama_token_sep(const struct llama_model * model) {
+    return model->vocab.special_sep_id;
+}
+
 llama_token llama_token_nl(const struct llama_model * model) {
     return model->vocab.linefeed_id;
 }
@@ -16138,9 +16183,9 @@ int32_t llama_tokenize(
                      int32_t   text_len,
                  llama_token * tokens,
                      int32_t   n_tokens_max,
-                        bool   add_bos,
-                        bool   special) {
-    auto res = llama_tokenize_internal(model->vocab, std::string(text, text_len), add_bos, special);
+                        bool   add_special,
+                        bool   parse_special) {
+    auto res = llama_tokenize_internal(model->vocab, std::string(text, text_len), add_special, parse_special);
 
     if (n_tokens_max < (int) res.size()) {
         // LLAMA_LOG_ERROR("%s: too many tokens\n", __func__);
diff --git a/llama.h b/llama.h
index 6a5bbe26d4de7848156da3f95819173f713da9a9..b770a275ff02fbf4ce664a4194a1a3b66478ad6c 100644 (file)
--- a/llama.h
+++ b/llama.h
@@ -786,6 +786,8 @@ extern "C" {
     // Special tokens
     LLAMA_API llama_token llama_token_bos(const struct llama_model * model); // beginning-of-sentence
     LLAMA_API llama_token llama_token_eos(const struct llama_model * model); // end-of-sentence
+    LLAMA_API llama_token llama_token_cls(const struct llama_model * model); // classification
+    LLAMA_API llama_token llama_token_sep(const struct llama_model * model); // sentence separator
     LLAMA_API llama_token llama_token_nl (const struct llama_model * model); // next-line
 
     // Returns -1 if unknown, 1 for true or 0 for false.
@@ -808,16 +810,16 @@ extern "C" {
     /// @param tokens The tokens pointer must be large enough to hold the resulting tokens.
     /// @return Returns the number of tokens on success, no more than n_tokens_max
     /// @return Returns a negative number on failure - the number of tokens that would have been returned
-    /// @param special Allow tokenizing special and/or control tokens which otherwise are not exposed and treated as plaintext.
-    ///                Does not insert a leading space.
+    /// @param parse_special Allow tokenizing special and/or control tokens which otherwise are not exposed and treated
+    ///                      as plaintext. Does not insert a leading space.
     LLAMA_API int32_t llama_tokenize(
         const struct llama_model * model,
                       const char * text,
                          int32_t   text_len,
                      llama_token * tokens,
                          int32_t   n_tokens_max,
-                            bool   add_bos,
-                            bool   special);
+                            bool   add_special,
+                            bool   parse_special);
 
     // Token Id -> Piece.
     // Uses the vocabulary in the provided context.