std::vector<llama_token> llama_tokenize(
const struct llama_context * ctx,
const std::string & text,
- bool add_bos,
- bool special) {
- return llama_tokenize(llama_get_model(ctx), text, add_bos, special);
+ bool add_special,
+ bool parse_special) {
+ return llama_tokenize(llama_get_model(ctx), text, add_special, parse_special);
}
std::vector<llama_token> llama_tokenize(
const struct llama_model * model,
const std::string & text,
- bool add_bos,
- bool special) {
+ bool add_special,
+ bool parse_special) {
// upper limit for the number of tokens
- int n_tokens = text.length() + add_bos;
+ int n_tokens = text.length() + 2 * add_special;
std::vector<llama_token> result(n_tokens);
- n_tokens = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_bos, special);
+ n_tokens = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_special, parse_special);
if (n_tokens < 0) {
result.resize(-n_tokens);
- int check = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_bos, special);
+ int check = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_special, parse_special);
GGML_ASSERT(check == -n_tokens);
} else {
result.resize(n_tokens);
std::vector<llama_token> llama_tokenize(
const struct llama_context * ctx,
const std::string & text,
- bool add_bos,
- bool special = false);
+ bool add_special,
+ bool parse_special = false);
std::vector<llama_token> llama_tokenize(
const struct llama_model * model,
const std::string & text,
- bool add_bos,
- bool special = false);
+ bool add_special,
+ bool parse_special = false);
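The renamed flags separate two decisions: add_special controls whether the model's own leading/trailing tokens (BOS, or CLS/SEP for WPM vocabs) are inserted, while parse_special controls whether special-token text inside the prompt is parsed into token ids instead of being tokenized as plain text. A minimal sketch of the std::string helper above (assumptions: an already-initialized `ctx`, a hypothetical wrapper name, and ChatML markers used purely for illustration):

    #include "common.h"

    // sketch only: tokenize a ChatML-style prompt with the renamed flags
    static std::vector<llama_token> tokenize_prompt(llama_context * ctx, const std::string & prompt) {
        // e.g. prompt = "<|im_start|>user\nhi<|im_end|>": with parse_special=true the markers become
        // single special tokens (if the vocab defines them); with false they are tokenized as plain text
        return llama_tokenize(ctx, prompt, /*add_special=*/true, /*parse_special=*/true);
    }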
// converts a token id into its text piece
// should work similarly to Python's `tokenizer.id_to_piece`
return ("pytorch_model.bin",)
return (f"pytorch_model-{n:05}-of-{self.num_parts:05}.bin" for n in range(1, self.num_parts + 1))
- def _set_vocab_gpt2(self):
- dir_model = self.dir_model
- hparams = self.hparams
+ # used for GPT-2 BPE and WordPiece vocabs
+ def get_basic_vocab(self) -> tuple[list[str], list[int]]:
tokens: list[str] = []
toktypes: list[int] = []
from transformers import AutoTokenizer
- tokenizer = AutoTokenizer.from_pretrained(dir_model)
- vocab_size = hparams.get("vocab_size", len(tokenizer.vocab))
+ tokenizer = AutoTokenizer.from_pretrained(self.dir_model)
+ vocab_size = self.hparams.get("vocab_size", len(tokenizer.vocab))
assert max(tokenizer.vocab.values()) < vocab_size
reverse_vocab = {id_: encoded_tok for encoded_tok, id_ in tokenizer.vocab.items()}
tokens.append(reverse_vocab[i])
toktypes.append(gguf.TokenType.NORMAL)
+ return tokens, toktypes
+
+ def _set_vocab_gpt2(self) -> None:
+ tokens, toktypes = self.get_basic_vocab()
self.gguf_writer.add_tokenizer_model("gpt2")
self.gguf_writer.add_token_list(tokens)
self.gguf_writer.add_token_types(toktypes)
- special_vocab = gguf.SpecialVocab(dir_model, load_merges=True)
+ special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True)
special_vocab.add_to_gguf(self.gguf_writer)
def _set_vocab_qwen(self):
self.gguf_writer.add_pooling_type(pooling_type)
def set_vocab(self):
- # use huggingface vocab to get all tokens
- vocab = LlamaHfVocab(self.dir_model, ignore_nonllama=True)
- tokens, scores, toktypes = zip(*vocab.all_tokens())
- assert len(tokens) == vocab.vocab_size
- self.vocab_size = vocab.vocab_size
+ tokens, toktypes = self.get_basic_vocab()
+ self.vocab_size = len(tokens)
# we need this to validate the size of the token_type embeddings
# though currently we are passing all zeros to the token_type embeddings
- n_token_types = len(set(toktypes))
- self.gguf_writer.add_token_type_count(n_token_types)
+ self.gguf_writer.add_token_type_count(2) # "Sequence A" or "Sequence B"
# convert to phantom space vocab
- def phantom(tok, typ):
- if tok.startswith(b"[") and tok.endswith(b"]"):
+ def phantom(tok):
+ if tok.startswith("[") and tok.endswith("]"):
return tok
- if tok.startswith(b"##"):
+ if tok.startswith("##"):
return tok[2:]
- return b"\xe2\x96\x81" + tok
- tokens = tuple(phantom(t, y) for t, y in zip(tokens, toktypes))
-
- # set up bos and eos tokens (cls and sep)
- self.gguf_writer.add_bos_token_id(vocab.tokenizer.cls_token_id)
- self.gguf_writer.add_eos_token_id(vocab.tokenizer.sep_token_id)
+ return "\u2581" + tok
+ tokens = list(map(phantom, tokens))
# add vocab to gguf
self.gguf_writer.add_tokenizer_model("bert")
self.gguf_writer.add_token_list(tokens)
- self.gguf_writer.add_token_scores(scores)
self.gguf_writer.add_token_types(toktypes)
# handle special tokens
super().set_gguf_parameters()
self.gguf_writer.add_rope_freq_base(self.hparams["rotary_emb_base"])
- def get_tensors(self):
- assert self.vocab_size is not None
- for name, data in super().get_tensors():
- # Nomic Embed's token embeddings tensor is padded, but llama.cpp wants tensor sizes to match exactly.
- if name == 'embeddings.word_embeddings.weight' and data.shape[1] != self.vocab_size:
- rounded_vocab_size = (self.vocab_size + 63) // 64 * 64
- assert data.shape == (rounded_vocab_size, self.hparams["n_embd"])
- data = data[:self.vocab_size, :]
- yield name, data
-
@Model.register("GemmaForCausalLM")
class GemmaModel(Model):
data = data.astype(np.float32)
# if f16 desired, convert big float32 2-dim weight tensors to float16
- if self.ftype == 1 and data_dtype == np.float32 and new_name.removesuffix(".weight").endswith((".ssm_in", ".ssm_out", "token_embd", "output")) and n_dims == 2:
+ new_weight_name = new_name[:-len(".weight")] if new_name.endswith(".weight") else ""
+ if self.ftype == 1 and data_dtype == np.float32 and new_weight_name.endswith((".ssm_in", ".ssm_out", "token_embd", "output")) and n_dims == 2:
data = data.astype(np.float16)
print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
#!/usr/bin/env python3
+from __future__ import annotations
+
import argparse
import os
import sys
import gguf
if TYPE_CHECKING:
- from typing import TypeAlias
+ from typing_extensions import Self, TypeAlias
if hasattr(faulthandler, 'register') and hasattr(signal, 'SIGUSR1'):
faulthandler.register(signal.SIGUSR1)
tokenizer_model = "llama"
name = "hfft"
- def __init__(self, base_path: Path, ignore_nonllama: bool = False):
+ def __init__(self, base_path: Path):
fname_tokenizer = base_path / FAST_TOKENIZER_FILE
# if this fails, FileNotFoundError propagates to caller
with open(fname_tokenizer, encoding='utf-8') as f:
# pre-check so we know if we need transformers
tokenizer_model: dict[str, Any] = tokenizer_json['model']
- if ignore_nonllama:
- pass # workaround incorrect use of this class for WordPiece
- elif (
+ if (
tokenizer_model['type'] != 'BPE' or not tokenizer_model.get('byte_fallback', False)
or tokenizer_json['decoder']['type'] != 'Sequence'
):
class Tensor(ABC):
+ ndarray: NDArray
data_type: DataType
@abstractmethod
- def astype(self, data_type: DataType) -> Tensor: ...
+ def astype(self, data_type: DataType) -> Self: ...
@abstractmethod
- def permute(self, n_head: int, n_head_kv: int) -> Tensor: ...
+ def permute(self, n_head: int, n_head_kv: int) -> Self: ...
@abstractmethod
- def permute_part(self, n_part: int, n_head: int, n_head_kv: int) -> UnquantizedTensor: ...
+ def permute_part(self, n_part: int, n_head: int, n_head_kv: int) -> Self: ...
@abstractmethod
- def part(self, n_part: int) -> UnquantizedTensor: ...
+ def part(self, n_part: int) -> Self: ...
@abstractmethod
def to_ggml(self) -> GGMLCompatibleTensor: ...
self.ndarray = ndarray
self.data_type = NUMPY_TYPE_TO_DATA_TYPE[ndarray.dtype]
- def astype(self, data_type: DataType) -> Tensor:
+ def astype(self, data_type: DataType) -> UnquantizedTensor:
dtype = data_type.dtype
if self.data_type == DT_BF16:
self.ndarray = bf16_to_fp32(self.ndarray)
return UnquantizedTensor(self.ndarray.astype(dtype))
- def to_ggml(self) -> UnquantizedTensor:
+ def to_ggml(self) -> Self:
return self
def permute_part(self, n_part: int, n_head: int, n_head_kv: int) -> UnquantizedTensor:
inputs.push_back(inp);
}
- // add eos if not present
+ // add SEP if not present
for (auto & inp : inputs) {
- if (inp.empty() || inp.back() != llama_token_eos(model)) {
- inp.push_back(llama_token_eos(model));
+ if (inp.empty() || inp.back() != llama_token_sep(model)) {
+ inp.push_back(llama_token_sep(model));
}
}
static bool compute_imatrix(llama_context * ctx, const gpt_params & params, bool compute_ppl, int from_chunk) {
const bool add_bos = llama_should_add_bos_token(llama_get_model(ctx));
+ GGML_ASSERT(llama_add_eos_token(llama_get_model(ctx)) != 1);
const int n_ctx = llama_n_ctx(ctx);
auto tim1 = std::chrono::high_resolution_clock::now();
fprintf(stderr, "%s: tokenizing the input ..\n", __func__);
- std::vector<llama_token> tokens = ::llama_tokenize(ctx, params.prompt, add_bos);
+ std::vector<llama_token> tokens = ::llama_tokenize(ctx, params.prompt, true);
auto tim2 = std::chrono::high_resolution_clock::now();
fprintf(stderr, "%s: tokenization took %g ms\n",__func__,1e-3*std::chrono::duration_cast<std::chrono::microseconds>(tim2-tim1).count());
LOG_TEE("%s\n", get_system_info(params).c_str());
}
const bool add_bos = llama_should_add_bos_token(model);
+ GGML_ASSERT(llama_add_eos_token(model) != 1);
LOG("add_bos: %d\n", add_bos);
bool suff_rm_leading_spc = params.escape;
if (ctx_guidance) {
LOG("cfg_negative_prompt: \"%s\"\n", log_tostr(sparams.cfg_negative_prompt));
- guidance_inp = ::llama_tokenize(ctx_guidance, sparams.cfg_negative_prompt, add_bos);
+ guidance_inp = ::llama_tokenize(ctx_guidance, sparams.cfg_negative_prompt, true);
LOG("guidance_inp tokenized: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_guidance, guidance_inp).c_str());
- std::vector<llama_token> original_inp = ::llama_tokenize(ctx, params.prompt, add_bos);
+ std::vector<llama_token> original_inp = ::llama_tokenize(ctx, params.prompt, true);
LOG("original_inp tokenized: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, original_inp).c_str());
original_prompt_len = original_inp.size();
int n_past = 0;
const int max_tgt_len = params->n_predict < 0 ? 256 : params->n_predict;
- const bool add_bos = llama_should_add_bos_token(llama_get_model(ctx_llava->ctx_llama));
std::string system_prompt, user_prompt;
size_t image_pos = prompt.find("<image>");
}
}
- eval_string(ctx_llava->ctx_llama, system_prompt.c_str(), params->n_batch, &n_past, add_bos);
+ eval_string(ctx_llava->ctx_llama, system_prompt.c_str(), params->n_batch, &n_past, true);
llava_eval_image_embed(ctx_llava->ctx_llama, image_embed, params->n_batch, &n_past);
eval_string(ctx_llava->ctx_llama, user_prompt.c_str(), params->n_batch, &n_past, false);
std::tie(model, ctx) = llama_init_from_gpt_params(params);
// Tokenize the prompt
- const bool add_bos = llama_should_add_bos_token(model);
- LOG("add_bos tgt: %d\n", add_bos);
-
std::vector<llama_token> inp;
std::vector<llama_token> all;
- inp = ::llama_tokenize(ctx, params.prompt, add_bos, true);
+ inp = ::llama_tokenize(ctx, params.prompt, true, true);
all = inp;
const int max_context_size = llama_n_ctx(ctx);
GGML_ASSERT(model != nullptr);
// tokenize the prompt
- const bool add_bos = llama_should_add_bos_token(model);
-
std::vector<llama_token> inp;
- inp = ::llama_tokenize(ctx, params.prompt, add_bos, true);
+ inp = ::llama_tokenize(ctx, params.prompt, true, true);
fprintf(stderr, "%s: tokenization done\n", __func__);
GGML_ASSERT(llama_n_vocab(model) < (1 << 16));
// tokenize the prompt
- const bool add_bos = llama_should_add_bos_token(model);
- LOG("add_bos tgt: %d\n", add_bos);
-
std::vector<llama_token> inp;
- inp = ::llama_tokenize(ctx, params.prompt, add_bos, true);
+ inp = ::llama_tokenize(ctx, params.prompt, true, true);
llama_ngram_cache ngram_cache_context;
llama_ngram_cache ngram_cache_dynamic;
GGML_ASSERT(llama_n_vocab(model) < (1 << 16));
// tokenize the prompt
- const bool add_bos = llama_should_add_bos_token(model);
- LOG("add_bos tgt: %d\n", add_bos);
-
std::vector<llama_token> inp;
- inp = ::llama_tokenize(ctx, params.prompt, add_bos, true);
+ inp = ::llama_tokenize(ctx, params.prompt, true, true);
llama_ngram_cache ngram_cache_context;
llama_ngram_cache ngram_cache_dynamic;
}
const bool add_bos = llama_should_add_bos_token(model);
+ GGML_ASSERT(llama_add_eos_token(model) != 1);
LOG("add_bos: %d\n", add_bos);
std::vector<llama_token> embd_inp;
if (params.chatml) {
params.prompt = "<|im_start|>system\n" + params.prompt + "<|im_end|>";
}
- embd_inp = ::llama_tokenize(ctx, params.prompt, add_bos, true);
+ embd_inp = ::llama_tokenize(ctx, params.prompt, true, true);
} else {
LOG("use session tokens\n");
embd_inp = session_tokens;
if (ctx_guidance) {
LOG("cfg_negative_prompt: \"%s\"\n", log_tostr(sparams.cfg_negative_prompt));
- guidance_inp = ::llama_tokenize(ctx_guidance, sparams.cfg_negative_prompt, add_bos, true);
+ guidance_inp = ::llama_tokenize(ctx_guidance, sparams.cfg_negative_prompt, true, true);
LOG("guidance_inp tokenized: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_guidance, guidance_inp).c_str());
- std::vector<llama_token> original_inp = ::llama_tokenize(ctx, params.prompt, add_bos, true);
+ std::vector<llama_token> original_inp = ::llama_tokenize(ctx, params.prompt, true, true);
LOG("original_inp tokenized: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, original_inp).c_str());
original_prompt_len = original_inp.size();
}
// prefix & suffix for instruct mode
- const auto inp_pfx = ::llama_tokenize(ctx, "\n\n### Instruction:\n\n", add_bos, true);
- const auto inp_sfx = ::llama_tokenize(ctx, "\n\n### Response:\n\n", false, true);
+ const auto inp_pfx = ::llama_tokenize(ctx, "\n\n### Instruction:\n\n", true, true);
+ const auto inp_sfx = ::llama_tokenize(ctx, "\n\n### Response:\n\n", false, true);
LOG("inp_pfx: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, inp_pfx).c_str());
LOG("inp_sfx: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, inp_sfx).c_str());
// chatml prefix & suffix
- const auto cml_pfx = ::llama_tokenize(ctx, "\n<|im_start|>user\n", add_bos, true);
+ const auto cml_pfx = ::llama_tokenize(ctx, "\n<|im_start|>user\n", true, true);
const auto cml_sfx = ::llama_tokenize(ctx, "<|im_end|>\n<|im_start|>assistant\n", false, true);
LOG("cml_pfx: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, cml_pfx).c_str());
// BOS tokens will be added for each chunk before eval
const bool add_bos = llama_should_add_bos_token(llama_get_model(ctx));
+ GGML_ASSERT(llama_add_eos_token(llama_get_model(ctx)) != 1);
fprintf(stderr, "%s: tokenizing the input ..\n", __func__);
- std::vector<llama_token> tokens = ::llama_tokenize(ctx, params.prompt, add_bos);
+ std::vector<llama_token> tokens = ::llama_tokenize(ctx, params.prompt, true);
const int n_ctx = llama_n_ctx(ctx);
// BOS tokens will be added for each chunk before eval
const bool add_bos = llama_should_add_bos_token(llama_get_model(ctx));
+ GGML_ASSERT(llama_add_eos_token(llama_get_model(ctx)) != 1);
std::ofstream logits_stream;
if (!params.logits_file.empty()) {
auto tim1 = std::chrono::high_resolution_clock::now();
fprintf(stderr, "%s: tokenizing the input ..\n", __func__);
- std::vector<llama_token> tokens = ::llama_tokenize(ctx, params.prompt, add_bos);
+ std::vector<llama_token> tokens = ::llama_tokenize(ctx, params.prompt, true);
auto tim2 = std::chrono::high_resolution_clock::now();
fprintf(stderr, "%s: tokenization took %g ms\n",__func__,1e-3*std::chrono::duration_cast<std::chrono::microseconds>(tim2-tim1).count());
const bool is_spm = llama_vocab_type(llama_get_model(ctx)) == LLAMA_VOCAB_TYPE_SPM;
fprintf(stderr, "================================= is_spm = %d\n", is_spm);
- // This is needed as usual for LLaMA models
- const bool add_bos = llama_should_add_bos_token(llama_get_model(ctx));
-
// The tasks should be randomized so the score stabilizes quickly.
bool randomize_tasks = true;
hs_cur.gold_ending_idx = std::stoi( prompt_lines[idx*6+1] );
for (size_t j = 0; j < 4; j++) {
hs_cur.ending[j] = prompt_lines[idx*6+2+j];
- hs_cur.seq_tokens[j] = ::llama_tokenize(ctx, hs_cur.context + " " + hs_cur.ending[j], add_bos);
+ hs_cur.seq_tokens[j] = ::llama_tokenize(ctx, hs_cur.context + " " + hs_cur.ending[j], true);
}
// determine the common prefix of the endings
hs_cur.seq_tokens[2].size() - hs_cur.common_prefix +
hs_cur.seq_tokens[3].size() - hs_cur.common_prefix;
- //GGML_ASSERT(hs_cur.common_prefix >= ::llama_tokenize(ctx, hs_cur.context, add_bos).size());
+ //GGML_ASSERT(hs_cur.common_prefix >= ::llama_tokenize(ctx, hs_cur.context, true).size());
// Delete the selected random example from the prompt
if (randomize_tasks) {
fprintf(stderr, "%s : tokenizing selected tasks\n", __func__);
- // This is needed as usual for LLaMA models
- const bool add_bos = llama_should_add_bos_token(llama_get_model(ctx));
-
for (auto & task : data) {
- task.seq_tokens[0] = ::llama_tokenize(ctx, task.first + task.choices[0] + task.second, add_bos);
- task.seq_tokens[1] = ::llama_tokenize(ctx, task.first + task.choices[1] + task.second, add_bos);
+ task.seq_tokens[0] = ::llama_tokenize(ctx, task.first + task.choices[0] + task.second, true);
+ task.seq_tokens[1] = ::llama_tokenize(ctx, task.first + task.choices[1] + task.second, true);
task.common_prefix = 0;
for (size_t k = 0; k < task.seq_tokens[0].size(); k++) {
task.seq_tokens[0].size() - task.common_prefix +
task.seq_tokens[1].size() - task.common_prefix;
- task.n_base1 = ::llama_tokenize(ctx, task.first + task.choices[0], add_bos).size();
- task.n_base2 = ::llama_tokenize(ctx, task.first + task.choices[1], add_bos).size();
+ task.n_base1 = ::llama_tokenize(ctx, task.first + task.choices[0], true).size();
+ task.n_base2 = ::llama_tokenize(ctx, task.first + task.choices[1], true).size();
}
fprintf(stderr, "%s : calculating winogrande score over selected tasks.\n", __func__);
std::vector<float> log_probs;
};
-static bool multiple_choice_prepare_one_task(llama_context * ctx, bool add_bos, multiple_choice_task& task, bool log_error) {
+static bool multiple_choice_prepare_one_task(llama_context * ctx, multiple_choice_task& task, bool log_error) {
if (task.question.empty() || task.mc1.answers.empty()) {
if (log_error) {
printf("%s: found bad task with empty question and/or answers\n", __func__);
}
return false;
}
- task.seq_tokens.emplace_back(::llama_tokenize(ctx, task.question + " " + answer, add_bos));
+ task.seq_tokens.emplace_back(::llama_tokenize(ctx, task.question + " " + answer, true));
}
auto min_len = task.seq_tokens.front().size();
for (auto& seq : task.seq_tokens) {
n_task = params.multiple_choice_tasks;
}
- // This is needed as usual for LLaMA models
- const bool add_bos = llama_should_add_bos_token(llama_get_model(ctx));
-
printf("%s: preparing task data", __func__);
fflush(stdout);
if (n_task > 500) {
fflush(stdout);
std::atomic<int> counter(0);
std::atomic<int> n_bad(0);
- auto prepare = [&counter, &n_bad, &tasks, ctx, add_bos] () {
+ auto prepare = [&counter, &n_bad, &tasks, ctx] () {
int num_tasks = tasks.size();
int n_bad_local = 0;
while (true) {
}
int last = std::min(first + K_TOKEN_CHUNK, num_tasks);
for (int i = first; i < last; ++i) {
- if (!multiple_choice_prepare_one_task(ctx, add_bos, tasks[i], false)) ++n_bad_local;
+ if (!multiple_choice_prepare_one_task(ctx, tasks[i], false)) ++n_bad_local;
}
}
};
int i_task = 0;
for (auto& task : tasks) {
++i_task;
- if (!multiple_choice_prepare_one_task(ctx, add_bos, task, true)) {
+ if (!multiple_choice_prepare_one_task(ctx, task, true)) {
return;
}
if (i_task%n_dot == 0) {
const int num_batches = (n_ctx + n_batch - 1)/n_batch;
const int nv = 2*((n_vocab + 1)/2) + 4;
const bool add_bos = llama_should_add_bos_token(llama_get_model(ctx));
+ GGML_ASSERT(llama_add_eos_token(llama_get_model(ctx)) != 1);
std::vector<uint16_t> log_probs_uint16(size_t(n_ctx - 1 - n_ctx/2) * nv);
std::vector<float> kld_values(size_t(n_ctx - 1 - n_ctx/2)*n_chunk);
n_ctx = llama_n_ctx(ctx);
add_bos_token = llama_should_add_bos_token(model);
+ GGML_ASSERT(llama_add_eos_token(model) != 1);
return true;
}
metrics.init();
}
- std::vector<llama_token> tokenize(const json & json_prompt, bool add_bos) const {
+ std::vector<llama_token> tokenize(const json & json_prompt, bool add_special) const {
// TODO: currently, we tokenize using special tokens by default
// this is not always correct (see https://github.com/ggerganov/llama.cpp/pull/4160#issuecomment-1824826216)
// but it's better compared to completely ignoring ChatML and other chat templates
std::vector<llama_token> p;
if (first) {
- p = ::llama_tokenize(ctx, s, add_bos, TMP_FORCE_SPECIAL);
+ p = ::llama_tokenize(ctx, s, add_special, TMP_FORCE_SPECIAL);
first = false;
} else {
p = ::llama_tokenize(ctx, s, false, TMP_FORCE_SPECIAL);
}
} else {
auto s = json_prompt.template get<std::string>();
- prompt_tokens = ::llama_tokenize(ctx, s, add_bos, TMP_FORCE_SPECIAL);
+ prompt_tokens = ::llama_tokenize(ctx, s, add_special, TMP_FORCE_SPECIAL);
}
return prompt_tokens;
system_tokens.clear();
if (!system_prompt.empty()) {
- system_tokens = ::llama_tokenize(ctx, system_prompt, add_bos_token);
+ system_tokens = ::llama_tokenize(ctx, system_prompt, true);
llama_batch_clear(batch);
prefix_tokens.push_back(llama_token_middle(model));
prompt_tokens = prefix_tokens;
} else {
- prompt_tokens = tokenize(slot.prompt, system_prompt.empty() && add_bos_token); // add BOS if there isn't system prompt
+ prompt_tokens = tokenize(slot.prompt, system_prompt.empty()); // add BOS only if there is no system prompt
}
slot.n_past = 0;
params.n_threads_batch = params.n_threads_batch_draft;
std::tie(model_dft, ctx_dft) = llama_init_from_gpt_params(params);
+ const enum llama_vocab_type vocab_type_tgt = llama_vocab_type(model_tgt);
+ LOG("vocab_type tgt: %d\n", vocab_type_tgt);
+
+ const enum llama_vocab_type vocab_type_dft = llama_vocab_type(model_dft);
+ LOG("vocab_type dft: %d\n", vocab_type_dft);
+
+ if (vocab_type_tgt != vocab_type_dft) {
+ fprintf(stderr, "%s: error: draft model vocab type must match target model to use speculation but ", __func__);
+ fprintf(stderr, "vocab_type_dft = %d while vocab_type_tgt = %d\n", vocab_type_dft, vocab_type_tgt);
+ return 1;
+ }
+
+ if (
+ llama_add_bos_token(model_tgt) != llama_add_bos_token(model_dft) ||
+ llama_add_eos_token(model_tgt) != llama_add_eos_token(model_dft) ||
+ llama_token_bos(model_tgt) != llama_token_bos(model_dft) ||
+ llama_token_eos(model_tgt) != llama_token_eos(model_dft)
+ ) {
+ fprintf(stderr, "%s: error: draft model special tokens must match target model to use speculation\n", __func__);
+ return 1;
+ }
+
{
const int n_vocab_tgt = llama_n_vocab(model_tgt);
const int n_vocab_dft = llama_n_vocab(model_dft);
// Tokenize the prompt
- const bool add_bos_tgt = llama_should_add_bos_token(model_tgt);
- LOG("add_bos tgt: %d\n", add_bos_tgt);
-
- const bool add_bos_dft = llama_should_add_bos_token(model_dft);
- LOG("add_bos dft: %d\n", add_bos_dft);
-
- if (add_bos_tgt != add_bos_dft) {
- fprintf(stderr, "%s: error: draft model add_bos must match target model to use speculation but ", __func__);
- fprintf(stderr, "add_bos_dft = %d while add_bos_tgt = %d\n", add_bos_dft, add_bos_tgt);
- return 1;
- }
-
std::vector<llama_token> inp;
- inp = ::llama_tokenize(ctx_tgt, params.prompt, add_bos_tgt, true);
+ inp = ::llama_tokenize(ctx_tgt, params.prompt, true, true);
const int max_context_size = llama_n_ctx(ctx_tgt);
const int max_tokens_list_size = max_context_size - 4;
llama_context_params ctx_params = llama_context_default_params();
llama_context * ctx = llama_new_context_with_model(model, ctx_params);
- const bool add_bos = llama_should_add_bos_token(model);
-
std::vector<llama_token> tokens;
- tokens = ::llama_tokenize(model, prompt, add_bos, true);
+ tokens = ::llama_tokenize(model, prompt, true, true);
for (int i = 0; i < (int) tokens.size(); i++) {
if (printing_ids) {
LLM_KV_TOKENIZER_UNK_ID,
LLM_KV_TOKENIZER_SEP_ID,
LLM_KV_TOKENIZER_PAD_ID,
+ LLM_KV_TOKENIZER_CLS_ID,
+ LLM_KV_TOKENIZER_MASK_ID,
LLM_KV_TOKENIZER_ADD_BOS,
LLM_KV_TOKENIZER_ADD_EOS,
LLM_KV_TOKENIZER_ADD_PREFIX,
{ LLM_KV_TOKENIZER_UNK_ID, "tokenizer.ggml.unknown_token_id" },
{ LLM_KV_TOKENIZER_SEP_ID, "tokenizer.ggml.seperator_token_id" },
{ LLM_KV_TOKENIZER_PAD_ID, "tokenizer.ggml.padding_token_id" },
+ { LLM_KV_TOKENIZER_CLS_ID, "tokenizer.ggml.cls_token_id" },
+ { LLM_KV_TOKENIZER_MASK_ID, "tokenizer.ggml.mask_token_id" },
{ LLM_KV_TOKENIZER_ADD_BOS, "tokenizer.ggml.add_bos_token" },
{ LLM_KV_TOKENIZER_ADD_EOS, "tokenizer.ggml.add_eos_token" },
{ LLM_KV_TOKENIZER_ADD_PREFIX, "tokenizer.ggml.add_space_prefix" },
std::map<std::pair<std::string, std::string>, int> bpe_ranks;
// default LLaMA special tokens
- id special_bos_id = 1;
- id special_eos_id = 2;
- id special_unk_id = 0;
- id special_sep_id = -1;
- id special_pad_id = -1;
+ id special_bos_id = 1;
+ id special_eos_id = 2;
+ id special_unk_id = 0;
+ id special_sep_id = -1;
+ id special_pad_id = -1;
+ id special_cls_id = -1;
+ id special_mask_id = -1;
int special_add_bos = -1; // -1 unknown, 1 add, 0 don't add.
int special_add_eos = -1; // -1 unknown, 1 add, 0 don't add.
}
// TODO: This should probably be in llama.h
-static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab & vocab, std::string raw_text, bool bos, bool special = false);
+static std::vector<llama_vocab::id> llama_tokenize_internal(
+ const llama_vocab & vocab, std::string raw_text, bool add_special, bool parse_special = false
+);
static llama_token llama_byte_to_token(const llama_vocab & vocab, uint8_t ch);
static void llm_load_vocab(
vocab.type = LLAMA_VOCAB_TYPE_NONE;
// default special tokens
- vocab.special_bos_id = -1;
- vocab.special_eos_id = -1;
- vocab.special_unk_id = -1;
- vocab.special_sep_id = -1;
- vocab.special_pad_id = -1;
- vocab.linefeed_id = -1;
+ vocab.special_bos_id = -1;
+ vocab.special_eos_id = -1;
+ vocab.special_unk_id = -1;
+ vocab.special_sep_id = -1;
+ vocab.special_pad_id = -1;
+ vocab.special_cls_id = -1;
+ vocab.special_mask_id = -1;
+ vocab.linefeed_id = -1;
return;
} else if (tokenizer_name == "llama") {
vocab.type = LLAMA_VOCAB_TYPE_SPM;
// default special tokens
- vocab.special_bos_id = 1;
- vocab.special_eos_id = 2;
- vocab.special_unk_id = 0;
- vocab.special_sep_id = -1;
- vocab.special_pad_id = -1;
+ vocab.special_bos_id = 1;
+ vocab.special_eos_id = 2;
+ vocab.special_unk_id = 0;
+ vocab.special_sep_id = -1;
+ vocab.special_pad_id = -1;
+ vocab.special_cls_id = -1;
+ vocab.special_mask_id = -1;
const int add_space_prefix_keyidx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_ADD_PREFIX).c_str());
if (add_space_prefix_keyidx != -1) {
}
// default special tokens
- vocab.special_bos_id = 11;
- vocab.special_eos_id = 11;
- vocab.special_unk_id = -1;
- vocab.special_sep_id = -1;
- vocab.special_pad_id = -1;
+ vocab.special_bos_id = 11;
+ vocab.special_eos_id = 11;
+ vocab.special_unk_id = -1;
+ vocab.special_sep_id = -1;
+ vocab.special_pad_id = -1;
+ vocab.special_cls_id = -1;
+ vocab.special_mask_id = -1;
} else if (tokenizer_name == "bert") {
vocab.type = LLAMA_VOCAB_TYPE_WPM;
// default special tokens
- vocab.special_bos_id = 101;
- vocab.special_eos_id = 102;
- vocab.special_unk_id = 100;
- vocab.special_sep_id = -1;
- vocab.special_pad_id = -1;
+ vocab.special_bos_id = -1;
+ vocab.special_eos_id = -1;
+ vocab.special_unk_id = 100;
+ vocab.special_sep_id = 102;
+ vocab.special_pad_id = 0;
+ vocab.special_cls_id = 101;
+ vocab.special_mask_id = 103;
vocab.add_space_prefix = false;
} else {
LLAMA_LOG_WARN("%s: unknown tokenizer: '%s'", __func__, tokenizer_name.c_str());
// special tokens
{
const std::vector<std::pair<enum llm_kv, int32_t &>> special_token_types = {
- { LLM_KV_TOKENIZER_BOS_ID, vocab.special_bos_id },
- { LLM_KV_TOKENIZER_EOS_ID, vocab.special_eos_id },
- { LLM_KV_TOKENIZER_UNK_ID, vocab.special_unk_id },
- { LLM_KV_TOKENIZER_SEP_ID, vocab.special_sep_id },
- { LLM_KV_TOKENIZER_PAD_ID, vocab.special_pad_id },
+ { LLM_KV_TOKENIZER_BOS_ID, vocab.special_bos_id },
+ { LLM_KV_TOKENIZER_EOS_ID, vocab.special_eos_id },
+ { LLM_KV_TOKENIZER_UNK_ID, vocab.special_unk_id },
+ { LLM_KV_TOKENIZER_SEP_ID, vocab.special_sep_id },
+ { LLM_KV_TOKENIZER_PAD_ID, vocab.special_pad_id },
+ { LLM_KV_TOKENIZER_CLS_ID, vocab.special_cls_id },
+ { LLM_KV_TOKENIZER_MASK_ID, vocab.special_mask_id },
};
for (const auto & it : special_token_types) {
const std::string & key = kv(std::get<0>(it));
LLAMA_LOG_INFO("%s: general.name = %s\n", __func__, model.name.c_str());
// special tokens
- if (vocab.special_bos_id != -1) { LLAMA_LOG_INFO( "%s: BOS token = %d '%s'\n", __func__, vocab.special_bos_id, vocab.id_to_token[vocab.special_bos_id].text.c_str() ); }
- if (vocab.special_eos_id != -1) { LLAMA_LOG_INFO( "%s: EOS token = %d '%s'\n", __func__, vocab.special_eos_id, vocab.id_to_token[vocab.special_eos_id].text.c_str() ); }
- if (vocab.special_unk_id != -1) { LLAMA_LOG_INFO( "%s: UNK token = %d '%s'\n", __func__, vocab.special_unk_id, vocab.id_to_token[vocab.special_unk_id].text.c_str() ); }
- if (vocab.special_sep_id != -1) { LLAMA_LOG_INFO( "%s: SEP token = %d '%s'\n", __func__, vocab.special_sep_id, vocab.id_to_token[vocab.special_sep_id].text.c_str() ); }
- if (vocab.special_pad_id != -1) { LLAMA_LOG_INFO( "%s: PAD token = %d '%s'\n", __func__, vocab.special_pad_id, vocab.id_to_token[vocab.special_pad_id].text.c_str() ); }
- if (vocab.linefeed_id != -1) { LLAMA_LOG_INFO( "%s: LF token = %d '%s'\n", __func__, vocab.linefeed_id, vocab.id_to_token[vocab.linefeed_id].text.c_str() ); }
+ if (vocab.special_bos_id != -1) { LLAMA_LOG_INFO( "%s: BOS token = %d '%s'\n", __func__, vocab.special_bos_id, vocab.id_to_token[vocab.special_bos_id].text.c_str() ); }
+ if (vocab.special_eos_id != -1) { LLAMA_LOG_INFO( "%s: EOS token = %d '%s'\n", __func__, vocab.special_eos_id, vocab.id_to_token[vocab.special_eos_id].text.c_str() ); }
+ if (vocab.special_unk_id != -1) { LLAMA_LOG_INFO( "%s: UNK token = %d '%s'\n", __func__, vocab.special_unk_id, vocab.id_to_token[vocab.special_unk_id].text.c_str() ); }
+ if (vocab.special_sep_id != -1) { LLAMA_LOG_INFO( "%s: SEP token = %d '%s'\n", __func__, vocab.special_sep_id, vocab.id_to_token[vocab.special_sep_id].text.c_str() ); }
+ if (vocab.special_pad_id != -1) { LLAMA_LOG_INFO( "%s: PAD token = %d '%s'\n", __func__, vocab.special_pad_id, vocab.id_to_token[vocab.special_pad_id].text.c_str() ); }
+ if (vocab.special_cls_id != -1) { LLAMA_LOG_INFO( "%s: CLS token = %d '%s'\n", __func__, vocab.special_cls_id, vocab.id_to_token[vocab.special_cls_id].text.c_str() ); }
+ if (vocab.special_mask_id != -1) { LLAMA_LOG_INFO( "%s: MASK token = %d '%s'\n", __func__, vocab.special_mask_id, vocab.id_to_token[vocab.special_mask_id].text.c_str() ); }
+ if (vocab.linefeed_id != -1) { LLAMA_LOG_INFO( "%s: LF token = %d '%s'\n", __func__, vocab.linefeed_id, vocab.id_to_token[vocab.linefeed_id].text.c_str() ); }
}
// Returns false if cancelled by progress_callback
output.push_back(vocab.special_unk_id);
}
}
-
- // append eos token
- output.push_back(vocab.special_eos_id);
}
std::vector<std::string> preprocess(const std::string & text) {
}
}
-static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab & vocab, std::string raw_text, bool bos, bool special) {
+static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab & vocab, std::string raw_text, bool add_special, bool parse_special) {
std::vector<llama_vocab::id> output;
-
- // OG tokenizer behavior:
- //
- // tokenizer.encode('', add_bos=True) returns [1]
- // tokenizer.encode('', add_bos=False) returns []
-
- if (bos && vocab.special_bos_id != -1) {
- output.push_back(vocab.special_bos_id);
- }
-
- if (raw_text.empty()) {
- return output;
- }
-
std::forward_list<fragment_buffer_variant> fragment_buffer;
- fragment_buffer.emplace_front(raw_text, 0, raw_text.length());
- if (special) tokenizer_st_partition(vocab, fragment_buffer);
+ if (!raw_text.empty()) {
+ fragment_buffer.emplace_front(raw_text, 0, raw_text.length());
+ if (parse_special) tokenizer_st_partition(vocab, fragment_buffer);
+ }
switch (vocab.type) {
case LLAMA_VOCAB_TYPE_SPM:
{
+ // OG tokenizer behavior:
+ //
+ // tokenizer.encode('', add_special_tokens=True) returns [1]
+ // tokenizer.encode('', add_special_tokens=False) returns []
+
+ if (add_special && vocab.special_add_bos != 0) {
+ GGML_ASSERT(vocab.special_bos_id != -1);
+ output.push_back(vocab.special_bos_id);
+ }
+
for (const auto & fragment : fragment_buffer) {
if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
// without adding this leading whitespace, we do not get the same results as the original tokenizer
output.push_back(fragment.token);
}
}
+
+ if (add_special && vocab.special_add_eos == 1) {
+ GGML_ASSERT(vocab.special_eos_id != -1);
+ output.push_back(vocab.special_eos_id);
+ }
} break;
case LLAMA_VOCAB_TYPE_BPE:
{
+ if (add_special && vocab.special_add_bos == 1) {
+ GGML_ASSERT(vocab.special_bos_id != -1);
+ output.push_back(vocab.special_bos_id);
+ }
+
for (const auto & fragment : fragment_buffer) {
if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length);
output.push_back(fragment.token);
}
}
+
+ GGML_ASSERT(vocab.special_add_eos != 1);
} break;
case LLAMA_VOCAB_TYPE_WPM:
{
+ if (add_special) {
+ GGML_ASSERT(vocab.special_cls_id != -1);
+ output.push_back(vocab.special_cls_id);
+ }
+
for (const auto & fragment : fragment_buffer) {
if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length);
output.push_back(fragment.token);
}
}
+
+ if (add_special) {
+ GGML_ASSERT(vocab.special_sep_id != -1);
+ output.push_back(vocab.special_sep_id);
+ }
} break;
case LLAMA_VOCAB_TYPE_NONE:
GGML_ASSERT(false);
return model->vocab.special_eos_id;
}
+llama_token llama_token_cls(const struct llama_model * model) {
+ return model->vocab.special_cls_id;
+}
+
+llama_token llama_token_sep(const struct llama_model * model) {
+ return model->vocab.special_sep_id;
+}
+
llama_token llama_token_nl(const struct llama_model * model) {
return model->vocab.linefeed_id;
}
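A hedged sketch of how a caller might use the new accessors to frame input for a WPM (BERT-style) model when it tokenizes with add_special=false and adds the classification framing itself (assumptions: a loaded `model` and `ctx`, and an input string `text`; with add_special=true the tokenizer now does this automatically):

    std::vector<llama_token> toks;
    toks.push_back(llama_token_cls(model));                                        // [CLS]
    const std::vector<llama_token> body = llama_tokenize(ctx, text, /*add_special=*/false, /*parse_special=*/false);
    toks.insert(toks.end(), body.begin(), body.end());
    toks.push_back(llama_token_sep(model));                                        // [SEP]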
int32_t text_len,
llama_token * tokens,
int32_t n_tokens_max,
- bool add_bos,
- bool special) {
- auto res = llama_tokenize_internal(model->vocab, std::string(text, text_len), add_bos, special);
+ bool add_special,
+ bool parse_special) {
+ auto res = llama_tokenize_internal(model->vocab, std::string(text, text_len), add_special, parse_special);
if (n_tokens_max < (int) res.size()) {
// LLAMA_LOG_ERROR("%s: too many tokens\n", __func__);
// Special tokens
LLAMA_API llama_token llama_token_bos(const struct llama_model * model); // beginning-of-sentence
LLAMA_API llama_token llama_token_eos(const struct llama_model * model); // end-of-sentence
+ LLAMA_API llama_token llama_token_cls(const struct llama_model * model); // classification
+ LLAMA_API llama_token llama_token_sep(const struct llama_model * model); // sentence separator
LLAMA_API llama_token llama_token_nl (const struct llama_model * model); // next-line
// Returns -1 if unknown, 1 for true or 0 for false.
/// @param tokens The tokens pointer must be large enough to hold the resulting tokens.
/// @return Returns the number of tokens on success, no more than n_tokens_max
/// @return Returns a negative number on failure - the number of tokens that would have been returned
- /// @param special Allow tokenizing special and/or control tokens which otherwise are not exposed and treated as plaintext.
- /// Does not insert a leading space.
+ /// @param parse_special Allow tokenizing special and/or control tokens which otherwise are not exposed and treated
+ /// as plaintext. Does not insert a leading space.
LLAMA_API int32_t llama_tokenize(
const struct llama_model * model,
const char * text,
int32_t text_len,
llama_token * tokens,
int32_t n_tokens_max,
- bool add_bos,
- bool special);
+ bool add_special,
+ bool parse_special);
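The negative-return contract above implies the usual two-call pattern, which the std::string helper in common.cpp implements; a minimal sketch of calling the C API directly (assumptions: a loaded `model` and a hypothetical wrapper name):

    #include <string>
    #include <vector>
    #include "llama.h"

    // sketch only: direct use of the C tokenizer API
    static std::vector<llama_token> tokenize_raw(const llama_model * model, const std::string & text) {
        std::vector<llama_token> tok(text.size() + 2);   // upper bound: one token per byte plus special tokens
        int32_t n = llama_tokenize(model, text.data(), (int32_t) text.size(),
                                   tok.data(), (int32_t) tok.size(),
                                   /*add_special=*/true, /*parse_special=*/false);
        if (n < 0) {          // buffer too small: -n is the exact token count needed
            tok.resize(-n);
            n = llama_tokenize(model, text.data(), (int32_t) text.size(),
                               tok.data(), (int32_t) tok.size(), true, false);
        }
        tok.resize(n);
        return tok;
    }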
// Token Id -> Piece.
// Uses the vocabulary in the provided context.