common/ngram-cache.o \
common/sampling.o \
common/train.o \
- common/grammar-parser.o \
common/build-info.o \
common/json-schema-to-grammar.o
common/console.h
$(CXX) $(CXXFLAGS) -c $< -o $@
-common/grammar-parser.o: \
- common/grammar-parser.cpp \
- common/grammar-parser.h
- $(CXX) $(CXXFLAGS) -c $< -o $@
-
common/json-schema-to-grammar.o: \
common/json-schema-to-grammar.cpp \
common/json-schema-to-grammar.h
sampling.cpp
console.h
console.cpp
- grammar-parser.h
- grammar-parser.cpp
json.hpp
json-schema-to-grammar.cpp
train.h
}
bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
- bool invalid_param = false;
- std::string arg;
- const std::string arg_prefix = "--";
- llama_sampling_params & sparams = params.sparams;
-
for (int i = 1; i < argc; i++) {
- arg = argv[i];
+ const std::string arg_prefix = "--";
+
+ std::string arg = argv[i];
if (arg.compare(0, arg_prefix.size(), arg_prefix) == 0) {
std::replace(arg.begin(), arg.end(), '_', '-');
}
+
+ bool invalid_param = false;
if (!gpt_params_find_arg(argc, argv, arg, params, i, invalid_param)) {
throw std::invalid_argument("error: unknown argument: " + arg);
}
get_env("HF_TOKEN", params.hf_token);
}
+ auto & sparams = params.sparams;
+
if (params.escape) {
string_process_escapes(params.prompt);
string_process_escapes(params.input_prefix);
string_process_escapes(params.input_suffix);
- string_process_escapes(sparams.cfg_negative_prompt);
for (auto & antiprompt : params.antiprompt) {
string_process_escapes(antiprompt);
}
params.kv_overrides.back().key[0] = 0;
}
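+
+    // LLAMA_DEFAULT_SEED means the seed was left unset: fall back to a time-based seed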
+ if (sparams.seed == LLAMA_DEFAULT_SEED) {
+ sparams.seed = time(NULL);
+ }
+
return true;
}
bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_params & params, int & i, bool & invalid_param) {
const char split_delim = ',';
- llama_sampling_params & sparams = params.sparams;
+ auto & sparams = params.sparams;
if (arg == "-s" || arg == "--seed") {
CHECK_ARG
- // TODO: this is temporary, in the future the sampling state will be moved fully to llama_sampling_context.
- params.seed = std::stoul(argv[i]);
sparams.seed = std::stoul(argv[i]);
return true;
}
if (arg == "--samplers") {
CHECK_ARG
const auto sampler_names = string_split(argv[i], ';');
- sparams.samplers_sequence = llama_sampling_types_from_names(sampler_names, true);
+ sparams.samplers = gpt_sampler_types_from_names(sampler_names, true);
return true;
}
if (arg == "--sampling-seq") {
CHECK_ARG
- sparams.samplers_sequence = llama_sampling_types_from_chars(argv[i]);
+ sparams.samplers = gpt_sampler_types_from_chars(argv[i]);
return true;
}
if (arg == "--top-p") {
}
if (arg == "--typical") {
CHECK_ARG
- sparams.typical_p = std::stof(argv[i]);
+ sparams.typ_p = std::stof(argv[i]);
return true;
}
if (arg == "--repeat-last-n") {
sparams.mirostat_tau = std::stof(argv[i]);
return true;
}
- if (arg == "--cfg-negative-prompt") {
- CHECK_ARG
- sparams.cfg_negative_prompt = argv[i];
- return true;
- }
- if (arg == "--cfg-negative-prompt-file") {
- CHECK_ARG
- std::ifstream file(argv[i]);
- if (!file) {
- fprintf(stderr, "error: failed to open file '%s'\n", argv[i]);
- invalid_param = true;
- return true;
- }
- std::copy(std::istreambuf_iterator<char>(file), std::istreambuf_iterator<char>(), back_inserter(sparams.cfg_negative_prompt));
- if (!sparams.cfg_negative_prompt.empty() && sparams.cfg_negative_prompt.back() == '\n') {
- sparams.cfg_negative_prompt.pop_back();
- }
- return true;
- }
- if (arg == "--cfg-scale") {
- CHECK_ARG
- sparams.cfg_scale = std::stof(argv[i]);
- return true;
- }
if (arg == "-b" || arg == "--batch-size") {
CHECK_ARG
params.n_batch = std::stoi(argv[i]);
return true;
}
if (arg == "--ignore-eos") {
- params.ignore_eos = true;
+ sparams.ignore_eos = true;
return true;
}
if (arg == "--penalize-nl") {
std::string value_str;
try {
if (ss >> key && ss >> sign && std::getline(ss, value_str) && (sign == '+' || sign == '-')) {
- sparams.logit_bias[key] = std::stof(value_str) * ((sign == '-') ? -1.0f : 1.0f);
+ const float bias = std::stof(value_str) * ((sign == '-') ? -1.0f : 1.0f);
+ sparams.logit_bias.push_back({key, bias});
}
else {
throw std::exception();
#endif
void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
- const llama_sampling_params & sparams = params.sparams;
+ const auto & sparams = params.sparams;
std::string sampler_type_chars;
std::string sampler_type_names;
- for (const auto sampler_type : sparams.samplers_sequence) {
- sampler_type_chars += static_cast<char>(sampler_type);
- sampler_type_names += llama_sampling_type_to_str(sampler_type) + ";";
+ for (const auto & sampler : sparams.samplers) {
+ sampler_type_chars += gpt_sampler_type_to_chr(sampler);
+ sampler_type_names += gpt_sampler_type_to_str(sampler) + ";";
}
sampler_type_names.pop_back();
options.push_back({ "*", " --verbose-prompt", "print a verbose prompt before generation (default: %s)", params.verbose_prompt ? "true" : "false" });
options.push_back({ "*", " --no-display-prompt", "don't print prompt at generation (default: %s)", !params.display_prompt ? "true" : "false" });
options.push_back({ "*", "-co, --color", "colorise output to distinguish prompt and user input from generations (default: %s)", params.use_color ? "true" : "false" });
- options.push_back({ "*", "-s, --seed SEED", "RNG seed (default: %d, use random seed for < 0)", params.seed });
options.push_back({ "*", "-t, --threads N", "number of threads to use during generation (default: %d)", params.cpuparams.n_threads });
options.push_back({ "*", "-tb, --threads-batch N", "number of threads to use during batch and prompt processing (default: same as --threads)" });
options.push_back({ "speculative", "-td, --threads-draft N", "number of threads to use during generation (default: same as --threads)" });
" --spm-infill", "use Suffix/Prefix/Middle pattern for infill (instead of Prefix/Suffix/Middle) as some models prefer this. (default: %s)", params.spm_infill ? "enabled" : "disabled" });
options.push_back({ "sampling" });
+ options.push_back({ "*", "-s, --seed SEED", "RNG seed (default: %d, use random seed for < 0)", sparams.seed });
options.push_back({ "*", " --samplers SAMPLERS", "samplers that will be used for generation in the order, separated by \';\'\n"
"(default: %s)", sampler_type_names.c_str() });
options.push_back({ "*", " --sampling-seq SEQUENCE",
"simplified sequence for samplers that will be used (default: %s)", sampler_type_chars.c_str() });
options.push_back({ "*", " --ignore-eos", "ignore end of stream token and continue generating (implies --logit-bias EOS-inf)" });
options.push_back({ "*", " --penalize-nl", "penalize newline tokens (default: %s)", sparams.penalize_nl ? "true" : "false" });
- options.push_back({ "*", " --temp N", "temperature (default: %.1f)", (double)sparams.temp });
+ options.push_back({ "*", " --temp T", "temperature (default: %.1f)", (double)sparams.temp });
options.push_back({ "*", " --top-k N", "top-k sampling (default: %d, 0 = disabled)", sparams.top_k });
- options.push_back({ "*", " --top-p N", "top-p sampling (default: %.1f, 1.0 = disabled)", (double)sparams.top_p });
- options.push_back({ "*", " --min-p N", "min-p sampling (default: %.1f, 0.0 = disabled)", (double)sparams.min_p });
- options.push_back({ "*", " --tfs N", "tail free sampling, parameter z (default: %.1f, 1.0 = disabled)", (double)sparams.tfs_z });
- options.push_back({ "*", " --typical N", "locally typical sampling, parameter p (default: %.1f, 1.0 = disabled)", (double)sparams.typical_p });
+ options.push_back({ "*", " --top-p P", "top-p sampling (default: %.1f, 1.0 = disabled)", (double)sparams.top_p });
+ options.push_back({ "*", " --min-p P", "min-p sampling (default: %.1f, 0.0 = disabled)", (double)sparams.min_p });
+ options.push_back({ "*", " --tfs P", "tail free sampling, parameter z (default: %.1f, 1.0 = disabled)", (double)sparams.tfs_z });
+ options.push_back({ "*", " --typical P", "locally typical sampling, parameter p (default: %.1f, 1.0 = disabled)", (double)sparams.typ_p });
options.push_back({ "*", " --repeat-last-n N", "last n tokens to consider for penalize (default: %d, 0 = disabled, -1 = ctx_size)", sparams.penalty_last_n });
options.push_back({ "*", " --repeat-penalty N", "penalize repeat sequence of tokens (default: %.1f, 1.0 = disabled)", (double)sparams.penalty_repeat });
options.push_back({ "*", " --presence-penalty N", "repeat alpha presence penalty (default: %.1f, 0.0 = disabled)", (double)sparams.penalty_present });
options.push_back({ "*", " -l TOKEN_ID(+/-)BIAS", "modifies the likelihood of token appearing in the completion,\n"
"i.e. `--logit-bias 15043+1` to increase likelihood of token ' Hello',\n"
"or `--logit-bias 15043-1` to decrease likelihood of token ' Hello'" });
- options.push_back({ "main", " --cfg-negative-prompt PROMPT",
- "negative prompt to use for guidance (default: '%s')", sparams.cfg_negative_prompt.c_str() });
- options.push_back({ "main", " --cfg-negative-prompt-file FNAME",
- "negative prompt file to use for guidance" });
- options.push_back({ "main", " --cfg-scale N", "strength of guidance (default: %.1f, 1.0 = disable)", (double)sparams.cfg_scale });
options.push_back({ "main", " --chat-template JINJA_TEMPLATE",
"set custom jinja chat template (default: template taken from model's metadata)\n"
"if suffix/prefix are specified, template will be disabled\n"
llama_lora_adapters_apply(lctx, iparams.lora_adapters);
}
- if (params.ignore_eos) {
- params.sparams.logit_bias[llama_token_eos(model)] = -INFINITY;
+ if (params.sparams.ignore_eos && llama_token_eos(model) == -1) {
+ fprintf(stderr, "%s: warning: model does not have an EOS token, ignoring --ignore-eos\n", __func__);
+ params.sparams.ignore_eos = false;
}
if (params.warmup) {
}
llama_kv_cache_clear(lctx);
llama_synchronize(lctx);
- llama_reset_timings(lctx);
+ llama_perf_reset(lctx, LLAMA_PERF_TYPE_CONTEXT);
}
iparams.model = model;
cparams.n_threads = params.cpuparams.n_threads;
cparams.n_threads_batch = params.cpuparams_batch.n_threads == -1 ?
params.cpuparams.n_threads : params.cpuparams_batch.n_threads;
- cparams.seed = params.seed;
cparams.logits_all = params.logits_all;
cparams.embeddings = params.embedding;
cparams.rope_scaling_type = params.rope_scaling_type;
void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const llama_context * lctx,
const std::string & timestamp, const std::vector<int> & prompt_tokens, const char * model_desc) {
- const llama_sampling_params & sparams = params.sparams;
+ const auto & sparams = params.sparams;
fprintf(stream, "build_commit: %s\n", LLAMA_COMMIT);
fprintf(stream, "build_number: %d\n", LLAMA_BUILD_NUMBER);
fprintf(stream, "alias: %s # default: unknown\n", params.model_alias.c_str());
fprintf(stream, "batch_size: %d # default: 512\n", params.n_batch);
- yaml_dump_string_multiline(stream, "cfg_negative_prompt", sparams.cfg_negative_prompt.c_str());
- fprintf(stream, "cfg_scale: %f # default: 1.0\n", sparams.cfg_scale);
fprintf(stream, "chunks: %d # default: -1 (unlimited)\n", params.n_chunks);
fprintf(stream, "color: %s # default: false\n", params.use_color ? "true" : "false");
fprintf(stream, "ctx_size: %d # default: 512\n", params.n_ctx);
fprintf(stream, "grammar-file: # never logged, see grammar instead. Can still be specified for input.\n");
fprintf(stream, "hellaswag: %s # default: false\n", params.hellaswag ? "true" : "false");
fprintf(stream, "hellaswag_tasks: %zu # default: 400\n", params.hellaswag_tasks);
-
- const auto logit_bias_eos = sparams.logit_bias.find(llama_token_eos(llama_get_model(lctx)));
- const bool ignore_eos = logit_bias_eos != sparams.logit_bias.end() && logit_bias_eos->second == -INFINITY;
- fprintf(stream, "ignore_eos: %s # default: false\n", ignore_eos ? "true" : "false");
+ fprintf(stream, "ignore_eos: %s # default: false\n", sparams.ignore_eos ? "true" : "false");
yaml_dump_string_multiline(stream, "in_prefix", params.input_prefix.c_str());
fprintf(stream, "in_prefix_bos: %s # default: false\n", params.input_prefix_bos ? "true" : "false");
fprintf(stream, "logdir: %s # default: unset (no logging)\n", params.logdir.c_str());
fprintf(stream, "logit_bias:\n");
- for (std::pair<llama_token, float> lb : sparams.logit_bias) {
- if (ignore_eos && lb.first == logit_bias_eos->first) {
- continue;
- }
- fprintf(stream, " %d: %f", lb.first, lb.second);
+ for (const auto & logit_bias : sparams.logit_bias) {
+ fprintf(stream, " %d: %f", logit_bias.token, logit_bias.bias);
}
fprintf(stream, "lora:\n");
fprintf(stream, "rope_freq_base: %f # default: 10000.0\n", params.rope_freq_base);
fprintf(stream, "rope_freq_scale: %f # default: 1.0\n", params.rope_freq_scale);
- fprintf(stream, "seed: %u # default: -1 (random seed)\n", params.seed);
fprintf(stream, "simple_io: %s # default: false\n", params.simple_io ? "true" : "false");
fprintf(stream, "cont_batching: %s # default: false\n", params.cont_batching ? "true" : "false");
fprintf(stream, "flash_attn: %s # default: false\n", params.flash_attn ? "true" : "false");
fprintf(stream, "top_k: %d # default: 40\n", sparams.top_k);
fprintf(stream, "top_p: %f # default: 0.95\n", sparams.top_p);
fprintf(stream, "min_p: %f # default: 0.0\n", sparams.min_p);
- fprintf(stream, "typical_p: %f # default: 1.0\n", sparams.typical_p);
+ fprintf(stream, "typ_p: %f # default: 1.0\n", sparams.typ_p);
fprintf(stream, "verbose_prompt: %s # default: false\n", params.verbose_prompt ? "true" : "false");
fprintf(stream, "display_prompt: %s # default: true\n", params.display_prompt ? "true" : "false");
}
};
struct gpt_params {
- uint32_t seed = LLAMA_DEFAULT_SEED; // RNG seed
-
int32_t n_predict = -1; // new tokens to predict
int32_t n_ctx = 0; // context size
int32_t n_batch = 2048; // logical batch size for prompt processing (must be >=32 to use BLAS)
enum llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_UNSPECIFIED; // pooling type for embeddings
enum llama_attention_type attention_type = LLAMA_ATTENTION_TYPE_UNSPECIFIED; // attention type for embeddings
- // // sampling parameters
- struct llama_sampling_params sparams;
+ struct gpt_sampler_params sparams;
std::string model = ""; // model path
std::string model_draft = ""; // draft model for speculative decoding
bool flash_attn = false; // flash attention
bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix
- bool ignore_eos = false; // ignore generated EOS tokens
bool logits_all = false; // return logits for all tokens in the batch
bool use_mmap = true; // use mmap for faster loads
bool use_mlock = false; // use mlock to keep model in memory
+++ /dev/null
-#include "grammar-parser.h"
-#include <cstdint>
-#include <cwchar>
-#include <string>
-#include <utility>
-#include <stdexcept>
-#include <exception>
-
-namespace grammar_parser {
- // NOTE: assumes valid utf8 (but checks for overrun)
- // copied from llama.cpp
- static std::pair<uint32_t, const char *> decode_utf8(const char * src) {
- static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 };
- uint8_t first_byte = static_cast<uint8_t>(*src);
- uint8_t highbits = first_byte >> 4;
- int len = lookup[highbits];
- uint8_t mask = (1 << (8 - len)) - 1;
- uint32_t value = first_byte & mask;
- const char * end = src + len; // may overrun!
- const char * pos = src + 1;
- for ( ; pos < end && *pos; pos++) {
- value = (value << 6) + (static_cast<uint8_t>(*pos) & 0x3F);
- }
- return std::make_pair(value, pos);
- }
-
- static uint32_t get_symbol_id(parse_state & state, const char * src, size_t len) {
- uint32_t next_id = static_cast<uint32_t>(state.symbol_ids.size());
- auto result = state.symbol_ids.emplace(std::string(src, len), next_id);
- return result.first->second;
- }
-
- static uint32_t generate_symbol_id(parse_state & state, const std::string & base_name) {
- uint32_t next_id = static_cast<uint32_t>(state.symbol_ids.size());
- state.symbol_ids[base_name + '_' + std::to_string(next_id)] = next_id;
- return next_id;
- }
-
- static void add_rule(
- parse_state & state,
- uint32_t rule_id,
- const std::vector<llama_grammar_element> & rule) {
- if (state.rules.size() <= rule_id) {
- state.rules.resize(rule_id + 1);
- }
- state.rules[rule_id] = rule;
- }
-
- static bool is_digit_char(char c) {
- return '0' <= c && c <= '9';
- }
-
- static bool is_word_char(char c) {
- return ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || c == '-' || is_digit_char(c);
- }
-
- static std::pair<uint32_t, const char *> parse_hex(const char * src, int size) {
- const char * pos = src;
- const char * end = src + size;
- uint32_t value = 0;
- for ( ; pos < end && *pos; pos++) {
- value <<= 4;
- char c = *pos;
- if ('a' <= c && c <= 'f') {
- value += c - 'a' + 10;
- } else if ('A' <= c && c <= 'F') {
- value += c - 'A' + 10;
- } else if ('0' <= c && c <= '9') {
- value += c - '0';
- } else {
- break;
- }
- }
- if (pos != end) {
- throw std::runtime_error("expecting " + std::to_string(size) + " hex chars at " + src);
- }
- return std::make_pair(value, pos);
- }
-
- static const char * parse_space(const char * src, bool newline_ok) {
- const char * pos = src;
- while (*pos == ' ' || *pos == '\t' || *pos == '#' ||
- (newline_ok && (*pos == '\r' || *pos == '\n'))) {
- if (*pos == '#') {
- while (*pos && *pos != '\r' && *pos != '\n') {
- pos++;
- }
- } else {
- pos++;
- }
- }
- return pos;
- }
-
- static const char * parse_name(const char * src) {
- const char * pos = src;
- while (is_word_char(*pos)) {
- pos++;
- }
- if (pos == src) {
- throw std::runtime_error(std::string("expecting name at ") + src);
- }
- return pos;
- }
-
- static const char * parse_int(const char * src) {
- const char * pos = src;
- while (is_digit_char(*pos)) {
- pos++;
- }
- if (pos == src) {
- throw std::runtime_error(std::string("expecting integer at ") + src);
- }
- return pos;
- }
-
- static std::pair<uint32_t, const char *> parse_char(const char * src) {
- if (*src == '\\') {
- switch (src[1]) {
- case 'x': return parse_hex(src + 2, 2);
- case 'u': return parse_hex(src + 2, 4);
- case 'U': return parse_hex(src + 2, 8);
- case 't': return std::make_pair('\t', src + 2);
- case 'r': return std::make_pair('\r', src + 2);
- case 'n': return std::make_pair('\n', src + 2);
- case '\\':
- case '"':
- case '[':
- case ']':
- return std::make_pair(src[1], src + 2);
- default:
- throw std::runtime_error(std::string("unknown escape at ") + src);
- }
- } else if (*src) {
- return decode_utf8(src);
- }
- throw std::runtime_error("unexpected end of input");
- }
-
- const char * parse_alternates(
- parse_state & state,
- const char * src,
- const std::string & rule_name,
- uint32_t rule_id,
- bool is_nested);
-
- static const char * parse_sequence(
- parse_state & state,
- const char * src,
- const std::string & rule_name,
- std::vector<llama_grammar_element> & out_elements,
- bool is_nested) {
- size_t last_sym_start = out_elements.size();
- const char * pos = src;
-
- auto handle_repetitions = [&](int min_times, int max_times) {
-
- if (last_sym_start == out_elements.size()) {
- throw std::runtime_error(std::string("expecting preceding item to */+/?/{ at ") + pos);
- }
-
- // apply transformation to previous symbol (last_sym_start to end) according to
- // the following rewrite rules:
- // S{m,n} --> S S S (m times) S'(n-m)
- // S'(x) ::= S S'(x-1) |
- // (... n-m definitions of these S' rules ...)
- // S'(1) ::= S |
- // S{m,} --> S S S (m times) S'
- // S' ::= S S' |
- // S* --> S{0,}
- // --> S' ::= S S' |
- // S+ --> S{1,}
- // --> S S'
- // S' ::= S S' |
- // S? --> S{0,1}
- // --> S'
- // S' ::= S |
-
- std::vector<llama_grammar_element> previous_elements(out_elements.begin() + last_sym_start, out_elements.end());
- if (min_times == 0) {
- out_elements.resize(last_sym_start);
- } else {
- // Repeat the previous elements (min_times - 1) times
- for (int i = 1; i < min_times; i++) {
- out_elements.insert(out_elements.end(), previous_elements.begin(), previous_elements.end());
- }
- }
-
- uint32_t last_rec_rule_id = 0;
- auto n_opt = max_times < 0 ? 1 : max_times - min_times;
-
- std::vector<llama_grammar_element> rec_rule(previous_elements);
- for (int i = 0; i < n_opt; i++) {
- rec_rule.resize(previous_elements.size());
- uint32_t rec_rule_id = generate_symbol_id(state, rule_name);
- if (i > 0 || max_times < 0) {
- rec_rule.push_back({LLAMA_GRETYPE_RULE_REF, max_times < 0 ? rec_rule_id : last_rec_rule_id});
- }
- rec_rule.push_back({LLAMA_GRETYPE_ALT, 0});
- rec_rule.push_back({LLAMA_GRETYPE_END, 0});
- add_rule(state, rec_rule_id, rec_rule);
- last_rec_rule_id = rec_rule_id;
- }
- if (n_opt > 0) {
- out_elements.push_back({LLAMA_GRETYPE_RULE_REF, last_rec_rule_id});
- }
- };
-
- while (*pos) {
- if (*pos == '"') { // literal string
- pos++;
- last_sym_start = out_elements.size();
- while (*pos != '"') {
- if (!*pos) {
- throw std::runtime_error("unexpected end of input");
- }
- auto char_pair = parse_char(pos);
- pos = char_pair.second;
- out_elements.push_back({LLAMA_GRETYPE_CHAR, char_pair.first});
- }
- pos = parse_space(pos + 1, is_nested);
- } else if (*pos == '[') { // char range(s)
- pos++;
- enum llama_gretype start_type = LLAMA_GRETYPE_CHAR;
- if (*pos == '^') {
- pos++;
- start_type = LLAMA_GRETYPE_CHAR_NOT;
- }
- last_sym_start = out_elements.size();
- while (*pos != ']') {
- if (!*pos) {
- throw std::runtime_error("unexpected end of input");
- }
- auto char_pair = parse_char(pos);
- pos = char_pair.second;
- enum llama_gretype type = last_sym_start < out_elements.size()
- ? LLAMA_GRETYPE_CHAR_ALT
- : start_type;
-
- out_elements.push_back({type, char_pair.first});
- if (pos[0] == '-' && pos[1] != ']') {
- if (!pos[1]) {
- throw std::runtime_error("unexpected end of input");
- }
- auto endchar_pair = parse_char(pos + 1);
- pos = endchar_pair.second;
- out_elements.push_back({LLAMA_GRETYPE_CHAR_RNG_UPPER, endchar_pair.first});
- }
- }
- pos = parse_space(pos + 1, is_nested);
- } else if (is_word_char(*pos)) { // rule reference
- const char * name_end = parse_name(pos);
- uint32_t ref_rule_id = get_symbol_id(state, pos, name_end - pos);
- pos = parse_space(name_end, is_nested);
- last_sym_start = out_elements.size();
- out_elements.push_back({LLAMA_GRETYPE_RULE_REF, ref_rule_id});
- } else if (*pos == '(') { // grouping
- // parse nested alternates into synthesized rule
- pos = parse_space(pos + 1, true);
- uint32_t sub_rule_id = generate_symbol_id(state, rule_name);
- pos = parse_alternates(state, pos, rule_name, sub_rule_id, true);
- last_sym_start = out_elements.size();
- // output reference to synthesized rule
- out_elements.push_back({LLAMA_GRETYPE_RULE_REF, sub_rule_id});
- if (*pos != ')') {
- throw std::runtime_error(std::string("expecting ')' at ") + pos);
- }
- pos = parse_space(pos + 1, is_nested);
- } else if (*pos == '.') { // any char
- last_sym_start = out_elements.size();
- out_elements.push_back({LLAMA_GRETYPE_CHAR_ANY, 0});
- pos = parse_space(pos + 1, is_nested);
- } else if (*pos == '*') {
- pos = parse_space(pos + 1, is_nested);
- handle_repetitions(0, -1);
- } else if (*pos == '+') {
- pos = parse_space(pos + 1, is_nested);
- handle_repetitions(1, -1);
- } else if (*pos == '?') {
- pos = parse_space(pos + 1, is_nested);
- handle_repetitions(0, 1);
- } else if (*pos == '{') {
- pos = parse_space(pos + 1, is_nested);
-
- if (!is_digit_char(*pos)) {
- throw std::runtime_error(std::string("expecting an int at ") + pos);
- }
- const char * int_end = parse_int(pos);
- int min_times = std::stoul(std::string(pos, int_end - pos));
- pos = parse_space(int_end, is_nested);
-
- int max_times = -1;
-
- if (*pos == '}') {
- max_times = min_times;
- pos = parse_space(pos + 1, is_nested);
- } else if (*pos == ',') {
- pos = parse_space(pos + 1, is_nested);
-
- if (is_digit_char(*pos)) {
- const char * int_end = parse_int(pos);
- max_times = std::stoul(std::string(pos, int_end - pos));
- pos = parse_space(int_end, is_nested);
- }
-
- if (*pos != '}') {
- throw std::runtime_error(std::string("expecting '}' at ") + pos);
- }
- pos = parse_space(pos + 1, is_nested);
- } else {
- throw std::runtime_error(std::string("expecting ',' at ") + pos);
- }
- handle_repetitions(min_times, max_times);
- } else {
- break;
- }
- }
- return pos;
- }
-
- const char * parse_alternates(
- parse_state & state,
- const char * src,
- const std::string & rule_name,
- uint32_t rule_id,
- bool is_nested) {
- std::vector<llama_grammar_element> rule;
- const char * pos = parse_sequence(state, src, rule_name, rule, is_nested);
- while (*pos == '|') {
- rule.push_back({LLAMA_GRETYPE_ALT, 0});
- pos = parse_space(pos + 1, true);
- pos = parse_sequence(state, pos, rule_name, rule, is_nested);
- }
- rule.push_back({LLAMA_GRETYPE_END, 0});
- add_rule(state, rule_id, rule);
- return pos;
- }
-
- static const char * parse_rule(parse_state & state, const char * src) {
- const char * name_end = parse_name(src);
- const char * pos = parse_space(name_end, false);
- size_t name_len = name_end - src;
- uint32_t rule_id = get_symbol_id(state, src, name_len);
- const std::string name(src, name_len);
-
- if (!(pos[0] == ':' && pos[1] == ':' && pos[2] == '=')) {
- throw std::runtime_error(std::string("expecting ::= at ") + pos);
- }
- pos = parse_space(pos + 3, true);
-
- pos = parse_alternates(state, pos, name, rule_id, false);
-
- if (*pos == '\r') {
- pos += pos[1] == '\n' ? 2 : 1;
- } else if (*pos == '\n') {
- pos++;
- } else if (*pos) {
- throw std::runtime_error(std::string("expecting newline or end at ") + pos);
- }
- return parse_space(pos, true);
- }
-
- parse_state parse(const char * src) {
- try {
- parse_state state;
- const char * pos = parse_space(src, true);
- while (*pos) {
- pos = parse_rule(state, pos);
- }
- // Validate the state to ensure that all rules are defined
- for (const auto & rule : state.rules) {
- if (rule.empty()) {
- throw std::runtime_error("Undefined rule");
- }
- for (const auto & elem : rule) {
- if (elem.type == LLAMA_GRETYPE_RULE_REF) {
- // Ensure that the rule at that location exists
- if (elem.value >= state.rules.size() || state.rules[elem.value].empty()) {
- // Get the name of the rule that is missing
- for (const auto & kv : state.symbol_ids) {
- if (kv.second == elem.value) {
- throw std::runtime_error("Undefined rule identifier '" + kv.first + "'");
- }
- }
- }
- }
- }
- }
- return state;
- } catch (const std::exception & err) {
- fprintf(stderr, "%s: error parsing grammar: %s\n", __func__, err.what());
- return parse_state();
- }
- }
-
- static void print_grammar_char(FILE * file, uint32_t c) {
- if (0x20 <= c && c <= 0x7f) {
- fprintf(file, "%c", static_cast<char>(c));
- } else {
- // cop out of encoding UTF-8
- fprintf(file, "<U+%04X>", c);
- }
- }
-
- static bool is_char_element(llama_grammar_element elem) {
- switch (elem.type) {
- case LLAMA_GRETYPE_CHAR: return true;
- case LLAMA_GRETYPE_CHAR_NOT: return true;
- case LLAMA_GRETYPE_CHAR_ALT: return true;
- case LLAMA_GRETYPE_CHAR_RNG_UPPER: return true;
- case LLAMA_GRETYPE_CHAR_ANY: return true;
- default: return false;
- }
- }
-
- static void print_rule_binary(FILE * file, const std::vector<llama_grammar_element> & rule) {
- for (auto elem : rule) {
- switch (elem.type) {
- case LLAMA_GRETYPE_END: fprintf(file, "END"); break;
- case LLAMA_GRETYPE_ALT: fprintf(file, "ALT"); break;
- case LLAMA_GRETYPE_RULE_REF: fprintf(file, "RULE_REF"); break;
- case LLAMA_GRETYPE_CHAR: fprintf(file, "CHAR"); break;
- case LLAMA_GRETYPE_CHAR_NOT: fprintf(file, "CHAR_NOT"); break;
- case LLAMA_GRETYPE_CHAR_RNG_UPPER: fprintf(file, "CHAR_RNG_UPPER"); break;
- case LLAMA_GRETYPE_CHAR_ALT: fprintf(file, "CHAR_ALT"); break;
- case LLAMA_GRETYPE_CHAR_ANY: fprintf(file, "CHAR_ANY"); break;
- }
- switch (elem.type) {
- case LLAMA_GRETYPE_END:
- case LLAMA_GRETYPE_ALT:
- case LLAMA_GRETYPE_RULE_REF:
- fprintf(file, "(%u) ", elem.value);
- break;
- case LLAMA_GRETYPE_CHAR:
- case LLAMA_GRETYPE_CHAR_NOT:
- case LLAMA_GRETYPE_CHAR_RNG_UPPER:
- case LLAMA_GRETYPE_CHAR_ALT:
- case LLAMA_GRETYPE_CHAR_ANY:
- fprintf(file, "(\"");
- print_grammar_char(file, elem.value);
- fprintf(file, "\") ");
- break;
- }
- }
- fprintf(file, "\n");
- }
-
- static void print_rule(
- FILE * file,
- uint32_t rule_id,
- const std::vector<llama_grammar_element> & rule,
- const std::map<uint32_t, std::string> & symbol_id_names) {
- if (rule.empty() || rule.back().type != LLAMA_GRETYPE_END) {
- throw std::runtime_error(
- "malformed rule, does not end with LLAMA_GRETYPE_END: " + std::to_string(rule_id));
- }
- fprintf(file, "%s ::= ", symbol_id_names.at(rule_id).c_str());
- for (size_t i = 0, end = rule.size() - 1; i < end; i++) {
- llama_grammar_element elem = rule[i];
- switch (elem.type) {
- case LLAMA_GRETYPE_END:
- throw std::runtime_error(
- "unexpected end of rule: " + std::to_string(rule_id) + "," +
- std::to_string(i));
- case LLAMA_GRETYPE_ALT:
- fprintf(file, "| ");
- break;
- case LLAMA_GRETYPE_RULE_REF:
- fprintf(file, "%s ", symbol_id_names.at(elem.value).c_str());
- break;
- case LLAMA_GRETYPE_CHAR:
- fprintf(file, "[");
- print_grammar_char(file, elem.value);
- break;
- case LLAMA_GRETYPE_CHAR_NOT:
- fprintf(file, "[^");
- print_grammar_char(file, elem.value);
- break;
- case LLAMA_GRETYPE_CHAR_RNG_UPPER:
- if (i == 0 || !is_char_element(rule[i - 1])) {
- throw std::runtime_error(
- "LLAMA_GRETYPE_CHAR_RNG_UPPER without preceding char: " +
- std::to_string(rule_id) + "," + std::to_string(i));
- }
- fprintf(file, "-");
- print_grammar_char(file, elem.value);
- break;
- case LLAMA_GRETYPE_CHAR_ALT:
- if (i == 0 || !is_char_element(rule[i - 1])) {
- throw std::runtime_error(
- "LLAMA_GRETYPE_CHAR_ALT without preceding char: " +
- std::to_string(rule_id) + "," + std::to_string(i));
- }
- print_grammar_char(file, elem.value);
- break;
- case LLAMA_GRETYPE_CHAR_ANY:
- fprintf(file, ".");
- break;
- }
- if (is_char_element(elem)) {
- switch (rule[i + 1].type) {
- case LLAMA_GRETYPE_CHAR_ALT:
- case LLAMA_GRETYPE_CHAR_RNG_UPPER:
- case LLAMA_GRETYPE_CHAR_ANY:
- break;
- default:
- fprintf(file, "] ");
- }
- }
- }
- fprintf(file, "\n");
- }
-
- void print_grammar(FILE * file, const parse_state & state) {
- try {
- std::map<uint32_t, std::string> symbol_id_names;
- for (const auto & kv : state.symbol_ids) {
- symbol_id_names[kv.second] = kv.first;
- }
- for (size_t i = 0, end = state.rules.size(); i < end; i++) {
- // fprintf(file, "%zu: ", i);
- // print_rule_binary(file, state.rules[i]);
- print_rule(file, uint32_t(i), state.rules[i], symbol_id_names);
- // fprintf(file, "\n");
- }
- } catch (const std::exception & err) {
- fprintf(stderr, "\n%s: error printing grammar: %s\n", __func__, err.what());
- }
- }
-
- std::vector<const llama_grammar_element *> parse_state::c_rules() {
- std::vector<const llama_grammar_element *> ret;
- ret.reserve(rules.size());
- for (const auto & rule : rules) {
- ret.push_back(rule.data());
- }
- return ret;
- }
-}
+++ /dev/null
-// Implements a parser for an extended Backus-Naur form (BNF), producing the
-// binary context-free grammar format specified by llama.h. Supports character
-// ranges, grouping, and repetition operators. As an example, a grammar for
-// arithmetic might look like:
-//
-// root ::= expr
-// expr ::= term ([-+*/] term)*
-// term ::= num | "(" space expr ")" space
-// num ::= [0-9]+ space
-// space ::= [ \t\n]*
-
-#pragma once
-#include "llama.h"
-#include <vector>
-#include <map>
-#include <cstdint>
-#include <string>
-
-namespace grammar_parser {
- struct parse_state {
- std::map<std::string, uint32_t> symbol_ids;
- std::vector<std::vector<llama_grammar_element>> rules;
-
- std::vector<const llama_grammar_element *> c_rules();
- };
-
- parse_state parse(const char * src);
- void print_grammar(FILE * file, const parse_state & state);
-}
-#define LLAMA_API_INTERNAL
#include "sampling.h"
-#include <random>
-struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params) {
- struct llama_sampling_context * result = new llama_sampling_context();
+#include "common.h"
- result->params = params;
- result->grammar = nullptr;
+// the ring buffer works similarly to std::deque, but with a fixed capacity
+// TODO: deduplicate with llama-impl.h
+template<typename T>
+struct ring_buffer {
+ ring_buffer(size_t cap) : capacity(cap), data(cap) {}
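+
+    // invariants: `first` indexes the oldest element, `pos` is the next slot that will be
+    // written, and `sz` is the number of stored elements (at most `capacity`).
+    // e.g. with capacity 3, pushing 1,2,3,4 keeps {2,3,4}: front() == 2, rat(0) == 4.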
- // if there is a grammar, parse it
- if (!params.grammar.empty()) {
- result->parsed_grammar = grammar_parser::parse(params.grammar.c_str());
-
- // will be empty (default) if there are parse errors
- if (result->parsed_grammar.rules.empty()) {
- fprintf(stderr, "%s: failed to parse grammar\n", __func__);
- delete result;
- return nullptr;
+ T & front() {
+ if (sz == 0) {
+ throw std::runtime_error("ring buffer is empty");
}
+ return data[first];
+ }
- // Ensure that there is a "root" node.
- if (result->parsed_grammar.symbol_ids.find("root") == result->parsed_grammar.symbol_ids.end()) {
- fprintf(stderr, "%s: grammar does not contain a 'root' symbol\n", __func__);
- delete result;
- return nullptr;
+ const T & front() const {
+ if (sz == 0) {
+ throw std::runtime_error("ring buffer is empty");
}
+ return data[first];
+ }
- std::vector<const llama_grammar_element *> grammar_rules(result->parsed_grammar.c_rules());
-
- struct llama_grammar * grammar = llama_grammar_init(
- grammar_rules.data(),
- grammar_rules.size(), result->parsed_grammar.symbol_ids.at("root"));
- if (grammar == nullptr) {
- throw std::runtime_error("Failed to initialize llama_grammar");
+ T & back() {
+ if (sz == 0) {
+ throw std::runtime_error("ring buffer is empty");
}
- result->grammar = grammar;
+ return data[pos];
}
- result->prev.resize(params.n_prev);
-
- result->n_valid = 0;
-
- llama_sampling_set_rng_seed(result, params.seed);
-
- return result;
-}
-
-void llama_sampling_free(struct llama_sampling_context * ctx) {
- if (ctx->grammar != NULL) {
- llama_grammar_free(ctx->grammar);
+ const T & back() const {
+ if (sz == 0) {
+ throw std::runtime_error("ring buffer is empty");
+ }
+ return data[pos];
}
- delete ctx;
-}
-
-void llama_sampling_reset(llama_sampling_context * ctx) {
- if (ctx->grammar != NULL) {
- llama_grammar_free(ctx->grammar);
- ctx->grammar = NULL;
+ void push_back(const T & value) {
+ if (sz == capacity) {
+ // advance the start when buffer is full
+ first = (first + 1) % capacity;
+ } else {
+ sz++;
+ }
+ data[pos] = value;
+ pos = (pos + 1) % capacity;
}
- if (!ctx->parsed_grammar.rules.empty()) {
- std::vector<const llama_grammar_element *> grammar_rules(ctx->parsed_grammar.c_rules());
+ T pop_front() {
+ if (sz == 0) {
+ throw std::runtime_error("ring buffer is empty");
+ }
+ T value = data[first];
+ first = (first + 1) % capacity;
+ sz--;
+ return value;
+ }
- struct llama_grammar * grammar = llama_grammar_init(
- grammar_rules.data(),
- grammar_rules.size(), ctx->parsed_grammar.symbol_ids.at("root"));
- if (grammar == nullptr) {
- throw std::runtime_error("Failed to initialize llama_grammar");
+ const T & rat(size_t i) const {
+ if (i >= sz) {
+ throw std::runtime_error("ring buffer: index out of bounds");
}
- ctx->grammar = grammar;
+ return data[(first + sz - i - 1) % capacity];
}
- std::fill(ctx->prev.begin(), ctx->prev.end(), 0);
- ctx->cur.clear();
- ctx->n_valid = 0;
-}
+ std::vector<T> to_vector() const {
+ std::vector<T> result;
+ result.reserve(sz);
+ for (size_t i = 0; i < sz; i++) {
+ result.push_back(data[(first + i) % capacity]);
+ }
+ return result;
+ }
-void llama_sampling_set_rng_seed(struct llama_sampling_context * ctx, uint32_t seed) {
- if (seed == LLAMA_DEFAULT_SEED) {
- seed = std::random_device{}();
+ void clear() {
+        // only reset the buffer state here; the allocated storage is kept
+ sz = 0;
+ first = 0;
+ pos = 0;
}
- ctx->rng.seed(seed);
-}
-void llama_sampling_cp(llama_sampling_context * src, llama_sampling_context * dst) {
- if (dst->grammar) {
- llama_grammar_free(dst->grammar);
- dst->grammar = nullptr;
+ bool empty() const {
+ return sz == 0;
}
- if (src->grammar) {
- dst->grammar = llama_grammar_copy(src->grammar);
+ size_t size() const {
+ return sz;
}
- dst->prev = src->prev;
-}
+ size_t capacity = 0;
+ size_t sz = 0;
+ size_t first = 0;
+ size_t pos = 0;
+ std::vector<T> data;
+};
-llama_token llama_sampling_last(llama_sampling_context * ctx) {
- return ctx->prev.back();
-}
+struct gpt_sampler {
+ gpt_sampler_params params;
-std::string llama_sampling_prev_str(llama_sampling_context * ctx_sampling, llama_context * ctx_main, int n) {
- const int size = ctx_sampling->prev.size();
+ struct llama_sampler * grmr;
+ struct llama_sampler * chain;
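+    // the grammar sampler is kept outside the chain so that gpt_sampler_sample can apply it
+    // either before the chain (grammar_first) or only as a post-hoc check on the sampled token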
- n = std::min(n, size);
+ ring_buffer<llama_token> prev;
- std::string result;
+ std::vector<llama_token_data> cur;
- for (int i = size - n; i < size; i++) {
- result += llama_token_to_piece(ctx_main, ctx_sampling->prev[i]);
- }
+ llama_token_data_array cur_p;
- return result;
-}
+ void set_logits(struct llama_context * ctx, int idx) {
+ const auto * logits = llama_get_logits_ith(ctx, idx);
+
+ const int n_vocab = llama_n_vocab(llama_get_model(ctx));
+
+ cur.resize(n_vocab);
+
+ for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
+ cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f};
+ }
+
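+        // llama_token_data_array fields: { data, size, selected, sorted };
+        // selected == -1 means no token has been chosen yet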
+ cur_p = { cur.data(), cur.size(), -1, false };
+ }
+};
-std::string llama_sampling_print(const llama_sampling_params & params) {
+std::string gpt_sampler_params::print() const {
char result[1024];
snprintf(result, sizeof(result),
"\trepeat_last_n = %d, repeat_penalty = %.3f, frequency_penalty = %.3f, presence_penalty = %.3f\n"
"\ttop_k = %d, tfs_z = %.3f, top_p = %.3f, min_p = %.3f, typical_p = %.3f, temp = %.3f\n"
"\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f",
- params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present,
- params.top_k, params.tfs_z, params.top_p, params.min_p, params.typical_p, params.temp,
- params.mirostat, params.mirostat_eta, params.mirostat_tau);
+ penalty_last_n, penalty_repeat, penalty_freq, penalty_present,
+ top_k, tfs_z, top_p, min_p, typ_p, temp,
+ mirostat, mirostat_eta, mirostat_tau);
return std::string(result);
}
-std::string llama_sampling_order_print(const llama_sampling_params & params) {
- std::string result = "CFG -> Penalties ";
- if (params.mirostat == 0) {
- for (auto sampler_type : params.samplers_sequence) {
- const auto sampler_type_name = llama_sampling_type_to_str(sampler_type);
- if (!sampler_type_name.empty()) {
- result += "-> " + sampler_type_name + " ";
+struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const struct gpt_sampler_params & params) {
+ llama_sampler_chain_params lparams = llama_sampler_chain_default_params();
+
+ lparams.no_perf = false; // TODO: control via params
+
+ auto * result = new gpt_sampler {
+ /* .params = */ params,
+ /* .grmr = */ llama_sampler_init_grammar(model, params.grammar.c_str(), "root"),
+ /* .chain = */ llama_sampler_chain_init(lparams),
+ /* .prev = */ ring_buffer<llama_token>(params.n_prev),
+ /* .cur = */ {},
+ /* .cur_p = */ {},
+ };
+
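+    // the chain is applied in order: logit biases, then the repetition/frequency/presence
+    // penalties (where ignore_eos and penalize_nl take effect), then either the user-selected
+    // samplers + softmax + dist, or mirostat, or greedy sampling when temp <= 0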
+ llama_sampler_chain_add(result->chain,
+ llama_sampler_init_logit_bias(
+ llama_n_vocab(model),
+ params.logit_bias.size(),
+ params.logit_bias.data()));
+
+ llama_sampler_chain_add(result->chain,
+ llama_sampler_init_penalties(
+ llama_n_vocab (model),
+ llama_token_eos(model),
+ llama_token_nl (model),
+ params.penalty_last_n,
+ params.penalty_repeat,
+ params.penalty_freq,
+ params.penalty_present,
+ params.penalize_nl,
+ params.ignore_eos));
+
+ if (params.temp > 0.0f) {
+ if (params.mirostat == 0) {
+ for (const auto & cnstr : params.samplers) {
+ switch (cnstr) {
+ case GPT_SAMPLER_TYPE_TOP_K:
+ llama_sampler_chain_add(result->chain, llama_sampler_init_top_k (params.top_k));
+ break;
+ case GPT_SAMPLER_TYPE_TOP_P:
+ llama_sampler_chain_add(result->chain, llama_sampler_init_top_p (params.top_p, params.min_keep));
+ break;
+ case GPT_SAMPLER_TYPE_MIN_P:
+ llama_sampler_chain_add(result->chain, llama_sampler_init_min_p (params.min_p, params.min_keep));
+ break;
+ case GPT_SAMPLER_TYPE_TFS_Z:
+ llama_sampler_chain_add(result->chain, llama_sampler_init_tail_free(params.tfs_z, params.min_keep));
+ break;
+ case GPT_SAMPLER_TYPE_TYPICAL_P:
+ llama_sampler_chain_add(result->chain, llama_sampler_init_typical (params.typ_p, params.min_keep));
+ break;
+ case GPT_SAMPLER_TYPE_TEMPERATURE:
+ llama_sampler_chain_add(result->chain, llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent));
+ break;
+ default:
+ GGML_ASSERT(false && "unknown sampler type");
+ }
}
+ llama_sampler_chain_add(result->chain, llama_sampler_init_softmax());
+ llama_sampler_chain_add(result->chain, llama_sampler_init_dist(params.seed));
+ } else if (params.mirostat == 1) {
+ llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp));
+ llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat(llama_n_vocab(model), params.seed, params.mirostat_tau, params.mirostat_eta, 100));
+ } else if (params.mirostat == 2) {
+ llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp));
+ llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat_v2(params.seed, params.mirostat_tau, params.mirostat_eta));
+ } else {
+ GGML_ASSERT(false && "unknown mirostat version");
}
} else {
- result += "-> mirostat ";
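+        // temp <= 0: greedy decoding; softmax is applied first so candidate probabilities are still populated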
+ llama_sampler_chain_add(result->chain, llama_sampler_init_softmax());
+ llama_sampler_chain_add(result->chain, llama_sampler_init_greedy());
}
return result;
}
-std::string llama_sampling_type_to_str(llama_sampler_type sampler_type) {
- switch (sampler_type) {
- case llama_sampler_type::TOP_K: return "top_k";
- case llama_sampler_type::TFS_Z: return "tfs_z";
- case llama_sampler_type::TYPICAL_P: return "typical_p";
- case llama_sampler_type::TOP_P: return "top_p";
- case llama_sampler_type::MIN_P: return "min_p";
- case llama_sampler_type::TEMPERATURE: return "temperature";
- default : return "";
+void gpt_sampler_free(struct gpt_sampler * gsmpl) {
+ if (gsmpl) {
+ llama_sampler_free(gsmpl->grmr);
+
+ llama_sampler_free(gsmpl->chain);
+
+ delete gsmpl;
}
}
-std::vector<llama_sampler_type> llama_sampling_types_from_names(const std::vector<std::string> & names, bool allow_alt_names) {
- std::unordered_map<std::string, llama_sampler_type> sampler_canonical_name_map {
- {"top_k", llama_sampler_type::TOP_K},
- {"top_p", llama_sampler_type::TOP_P},
- {"typical_p", llama_sampler_type::TYPICAL_P},
- {"min_p", llama_sampler_type::MIN_P},
- {"tfs_z", llama_sampler_type::TFS_Z},
- {"temperature", llama_sampler_type::TEMPERATURE}
- };
+void gpt_sampler_accept(struct gpt_sampler * gsmpl, llama_token token, bool accept_grammar) {
+ if (accept_grammar) {
+ llama_sampler_accept(gsmpl->grmr, token);
+ }
- // since samplers names are written multiple ways
- // make it ready for both system names and input names
- std::unordered_map<std::string, llama_sampler_type> sampler_alt_name_map {
- {"top-k", llama_sampler_type::TOP_K},
- {"top-p", llama_sampler_type::TOP_P},
- {"nucleus", llama_sampler_type::TOP_P},
- {"typical-p", llama_sampler_type::TYPICAL_P},
- {"typical", llama_sampler_type::TYPICAL_P},
- {"min-p", llama_sampler_type::MIN_P},
- {"tfs-z", llama_sampler_type::TFS_Z},
- {"tfs", llama_sampler_type::TFS_Z},
- {"temp", llama_sampler_type::TEMPERATURE}
- };
+ llama_sampler_accept(gsmpl->chain, token);
- std::vector<llama_sampler_type> sampler_types;
- sampler_types.reserve(names.size());
- for (const auto & name : names)
- {
- auto sampler_item = sampler_canonical_name_map.find(name);
- if (sampler_item != sampler_canonical_name_map.end())
- {
- sampler_types.push_back(sampler_item->second);
- }
- else
- {
- if (allow_alt_names)
- {
- sampler_item = sampler_alt_name_map.find(name);
- if (sampler_item != sampler_alt_name_map.end())
- {
- sampler_types.push_back(sampler_item->second);
- }
- }
- }
- }
- return sampler_types;
+ gsmpl->prev.push_back(token);
}
-std::vector<llama_sampler_type> llama_sampling_types_from_chars(const std::string & names_string) {
- std::unordered_map<char, llama_sampler_type> sampler_name_map {
- {'k', llama_sampler_type::TOP_K},
- {'p', llama_sampler_type::TOP_P},
- {'y', llama_sampler_type::TYPICAL_P},
- {'m', llama_sampler_type::MIN_P},
- {'f', llama_sampler_type::TFS_Z},
- {'t', llama_sampler_type::TEMPERATURE}
- };
+void gpt_sampler_reset(struct gpt_sampler * gsmpl) {
+ llama_sampler_reset(gsmpl->grmr);
- std::vector<llama_sampler_type> sampler_types;
- sampler_types.reserve(names_string.size());
- for (const auto & c : names_string) {
- const auto sampler_item = sampler_name_map.find(c);
- if (sampler_item != sampler_name_map.end()) {
- sampler_types.push_back(sampler_item->second);
- }
- }
- return sampler_types;
+ llama_sampler_reset(gsmpl->chain);
}
-// no reasons to expose this function in header
-static void sampler_queue(
- struct llama_context * ctx_main,
- const llama_sampling_params & params,
- llama_token_data_array & cur_p,
- size_t min_keep) {
- const float temp = params.temp;
- const float dynatemp_range = params.dynatemp_range;
- const float dynatemp_exponent = params.dynatemp_exponent;
- const int32_t top_k = params.top_k;
- const float top_p = params.top_p;
- const float min_p = params.min_p;
- const float tfs_z = params.tfs_z;
- const float typical_p = params.typical_p;
- const std::vector<llama_sampler_type> & samplers_sequence = params.samplers_sequence;
-
- for (auto sampler_type : samplers_sequence) {
- switch (sampler_type) {
- case llama_sampler_type::TOP_K : llama_sample_top_k (ctx_main, &cur_p, top_k, min_keep); break;
- case llama_sampler_type::TFS_Z : llama_sample_tail_free(ctx_main, &cur_p, tfs_z, min_keep); break;
- case llama_sampler_type::TYPICAL_P: llama_sample_typical (ctx_main, &cur_p, typical_p, min_keep); break;
- case llama_sampler_type::TOP_P : llama_sample_top_p (ctx_main, &cur_p, top_p, min_keep); break;
- case llama_sampler_type::MIN_P : llama_sample_min_p (ctx_main, &cur_p, min_p, min_keep); break;
- case llama_sampler_type::TEMPERATURE:
- if (dynatemp_range > 0) {
- float dynatemp_min = std::max(0.0f, temp - dynatemp_range);
- float dynatemp_max = std::max(0.0f, temp + dynatemp_range);
- llama_sample_entropy(ctx_main, &cur_p, dynatemp_min, dynatemp_max, dynatemp_exponent);
- } else {
- llama_sample_temp(ctx_main, &cur_p, temp);
- }
- break;
- default : break;
- }
- }
+struct gpt_sampler * gpt_sampler_clone(gpt_sampler * gsmpl) {
+ return new gpt_sampler {
+ /* .params = */ gsmpl->params,
+ /* .grmr = */ llama_sampler_clone(gsmpl->grmr),
+ /* .chain = */ llama_sampler_clone(gsmpl->chain),
+ /* .prev = */ gsmpl->prev,
+ /* .cur = */ gsmpl->cur,
+ /* .cur_p = */ gsmpl->cur_p,
+ };
}
-static llama_token llama_sampling_sample_impl(
- struct llama_sampling_context * ctx_sampling,
- struct llama_context * ctx_main,
- struct llama_context * ctx_cfg,
- const int idx,
- bool is_resampling) {
- const llama_sampling_params & params = ctx_sampling->params;
-
- const float temp = params.temp;
- const int mirostat = params.mirostat;
- const float mirostat_tau = params.mirostat_tau;
- const float mirostat_eta = params.mirostat_eta;
-
- std::vector<float> original_logits;
- auto cur_p = llama_sampling_prepare(ctx_sampling, ctx_main, ctx_cfg, idx, /* apply_grammar= */ is_resampling, &original_logits);
- if (ctx_sampling->grammar != NULL && !is_resampling) {
- GGML_ASSERT(!original_logits.empty());
- }
- llama_token id = 0;
-
- if (temp < 0.0) {
- // greedy sampling, with probs
- llama_sample_softmax(ctx_main, &cur_p);
- id = cur_p.data[0].id;
- } else if (temp == 0.0) {
- // greedy sampling, no probs
- id = llama_sample_token_greedy(ctx_main, &cur_p);
- } else {
- if (mirostat == 1) {
- const int mirostat_m = 100;
- llama_sample_temp(ctx_main, &cur_p, temp);
- id = llama_sample_token_mirostat(ctx_main, &cur_p, mirostat_tau, mirostat_eta, mirostat_m, &ctx_sampling->mirostat_mu);
- } else if (mirostat == 2) {
- llama_sample_temp(ctx_main, &cur_p, temp);
- id = llama_sample_token_mirostat_v2(ctx_main, &cur_p, mirostat_tau, mirostat_eta, &ctx_sampling->mirostat_mu);
- } else {
- // temperature sampling
- size_t min_keep = std::max(1, params.min_keep);
-
- sampler_queue(ctx_main, params, cur_p, min_keep);
+void gpt_perf_print(const struct llama_context * ctx, const struct gpt_sampler * gsmpl) {
+ // TODO: measure grammar performance
- id = llama_sample_token_with_rng(ctx_main, &cur_p, ctx_sampling->rng);
+ if (gsmpl) {
+ llama_perf_print(gsmpl->chain, LLAMA_PERF_TYPE_SAMPLER_CHAIN);
+ }
+ if (ctx) {
+ llama_perf_print(ctx, LLAMA_PERF_TYPE_CONTEXT);
+ }
+}
- //{
- // const int n_top = 10;
- // LOG("top %d candidates:\n", n_top);
+llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first) {
+ gsmpl->set_logits(ctx, idx);
- // for (int i = 0; i < n_top; i++) {
- // const llama_token id = cur_p.data[i].id;
- // (void)id; // To avoid a warning that id is unused when logging is disabled.
- // LOG(" - %5d: '%12s' (%.3f)\n", id, llama_token_to_piece(ctx_main, id).c_str(), cur_p.data[i].p);
- // }
- //}
+ auto & grmr = gsmpl->grmr;
+ auto & chain = gsmpl->chain;
+ auto & cur_p = gsmpl->cur_p; // initialized by set_logits
- //LOG("sampled token: %5d: '%s'\n", id, llama_token_to_piece(ctx_main, id).c_str());
- }
+ if (grammar_first) {
+ llama_sampler_apply(grmr, &cur_p);
}
- if (ctx_sampling->grammar != NULL && !is_resampling) {
- // Get a pointer to the logits
- float * logits = llama_get_logits_ith(ctx_main, idx);
+ llama_sampler_apply(chain, &cur_p);
- // Create an array with a single token data element for the sampled id
- llama_token_data single_token_data = {id, logits[id], 0.0f};
- llama_token_data_array single_token_data_array = { &single_token_data, 1, false };
+ GGML_ASSERT(cur_p.selected != -1 && "no selected token during sampling - check your sampling configuration");
- // Apply grammar constraints to the single token
- llama_grammar_sample(ctx_sampling->grammar, ctx_main, &single_token_data_array);
+ const llama_token id = cur_p.data[cur_p.selected].id;
- // Check if the token is valid according to the grammar by seeing if its logit has been set to -INFINITY
- bool is_valid = single_token_data_array.data[0].logit != -INFINITY;
+ if (grammar_first) {
+ return id;
+ }
- // If the token is not valid according to the grammar, perform resampling
- if (!is_valid) {
- LOG("Resampling because token %d: '%s' does not meet grammar rules\n", id, llama_token_to_piece(ctx_main, id).c_str());
+    // check if the sampled token fits the grammar
+ {
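+        // run only the sampled token through the grammar sampler; if the grammar rejects it,
+        // its logit is set to -INFINITY and we fall through to resampling below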
+ llama_token_data single_token_data = { id, 1.0f, 0.0f };
+ llama_token_data_array single_token_data_array = { &single_token_data, 1, -1, false };
- // Restore logits from the copy
- std::copy(original_logits.begin(), original_logits.end(), logits);
+ llama_sampler_apply(grmr, &single_token_data_array);
- return llama_sampling_sample_impl(ctx_sampling, ctx_main, ctx_cfg, idx, /* is_resampling= */ true);
+ const bool is_valid = single_token_data_array.data[0].logit != -INFINITY;
+ if (is_valid) {
+ return id;
}
}
- ctx_sampling->n_valid = temp == 0.0f ? 0 : cur_p.size;
+ // resampling:
+ // if the token is not valid, sample again, but first apply the grammar sampler and then the sampling chain
+ gsmpl->set_logits(ctx, idx);
- return id;
-}
+ llama_sampler_apply(grmr, &cur_p);
+ llama_sampler_apply(chain, &cur_p);
-static llama_token_data_array llama_sampling_prepare_impl(
- struct llama_sampling_context * ctx_sampling,
- struct llama_context * ctx_main,
- struct llama_context * ctx_cfg,
- const int idx,
- bool apply_grammar,
- std::vector<float> * original_logits) {
- const llama_sampling_params & params = ctx_sampling->params;
+ GGML_ASSERT(cur_p.selected != -1 && "no selected token during re-sampling - check your sampling configuration");
- const int n_vocab = llama_n_vocab(llama_get_model(ctx_main));
+ return cur_p.data[cur_p.selected].id;
+}
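+
+// usage sketch (illustrative; assumes a loaded model and context):
+//
+//     struct gpt_sampler * smpl = gpt_sampler_init(model, params.sparams);
+//     const llama_token id = gpt_sampler_sample(smpl, ctx, /* idx */ -1, /* grammar_first */ false);
+//     gpt_sampler_accept(smpl, id, /* accept_grammar */ true);
+//     // ... decode, then repeat; finally:
+//     gpt_sampler_free(smpl);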
- const int32_t penalty_last_n = params.penalty_last_n < 0 ? params.n_prev : params.penalty_last_n;
- const float penalty_repeat = params.penalty_repeat;
- const float penalty_freq = params.penalty_freq;
- const float penalty_present = params.penalty_present;
+// helpers
- const bool penalize_nl = params.penalize_nl;
+llama_token_data_array * gpt_sampler_get_candidates(struct gpt_sampler * gsmpl) {
+ return &gsmpl->cur_p;
+}
- auto & prev = ctx_sampling->prev;
- auto & cur = ctx_sampling->cur;
+llama_token gpt_sampler_last(const struct gpt_sampler * gsmpl) {
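+    // rat(0) is the reverse-indexed element 0, i.e. the most recently accepted token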
+ return gsmpl->prev.rat(0);
+}
- // Get a pointer to the logits
- float * logits = llama_get_logits_ith(ctx_main, idx);
+std::string gpt_sampler_print(const struct gpt_sampler * gsmpl) {
+ std::string result = "\tlogits ";
- if (ctx_sampling->grammar != NULL && !apply_grammar) {
- GGML_ASSERT(original_logits != NULL);
- // Only make a copy of the original logits if we are not applying grammar checks, not sure if I actually have to do this.
- *original_logits = {logits, logits + n_vocab};
+ for (int i = 0; i < llama_sampler_chain_n(gsmpl->chain); i++) {
+ const auto * smpl = llama_sampler_chain_get(gsmpl->chain, i);
+ result += std::string("-> ") + llama_sampler_name(smpl) + " ";
}
- // apply params.logit_bias map
- for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) {
- logits[it->first] += it->second;
+ return result;
+}
+
+std::string gpt_sampler_prev_str(gpt_sampler * gsmpl, llama_context * ctx_main, int n) {
+ n = std::min(n, (int) gsmpl->prev.size());
+
+ if (n <= 0) {
+ return "";
}
- if (ctx_cfg) {
- float * logits_guidance = llama_get_logits_ith(ctx_cfg, idx);
- llama_sample_apply_guidance(ctx_main, logits, logits_guidance, params.cfg_scale);
+ std::string result;
+ result.reserve(8*n); // 8 is the average length of a token [citation needed], TODO: compute this from the vocab
+
+ for (int i = n - 1; i >= 0; i--) {
+ const llama_token id = gsmpl->prev.rat(i);
+
+ GGML_ASSERT(id != LLAMA_TOKEN_NULL && "null token in the sampling history - should not happen");
+
+ result += llama_token_to_piece(ctx_main, id);
}
- cur.resize(n_vocab);
+ return result;
+}
- for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
- cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f};
+char gpt_sampler_type_to_chr(enum gpt_sampler_type cnstr) {
+ switch (cnstr) {
+ case GPT_SAMPLER_TYPE_TOP_K: return 'k';
+ case GPT_SAMPLER_TYPE_TFS_Z: return 'f';
+ case GPT_SAMPLER_TYPE_TYPICAL_P: return 'y';
+ case GPT_SAMPLER_TYPE_TOP_P: return 'p';
+ case GPT_SAMPLER_TYPE_MIN_P: return 'm';
+ case GPT_SAMPLER_TYPE_TEMPERATURE: return 't';
+ default : return '?';
}
+}
- llama_token_data_array cur_p = { cur.data(), cur.size(), false };
+std::string gpt_sampler_type_to_str(enum gpt_sampler_type cnstr) {
+ switch (cnstr) {
+ case GPT_SAMPLER_TYPE_TOP_K: return "top_k";
+ case GPT_SAMPLER_TYPE_TFS_Z: return "tfs_z";
+ case GPT_SAMPLER_TYPE_TYPICAL_P: return "typ_p";
+ case GPT_SAMPLER_TYPE_TOP_P: return "top_p";
+ case GPT_SAMPLER_TYPE_MIN_P: return "min_p";
+ case GPT_SAMPLER_TYPE_TEMPERATURE: return "temperature";
+ default : return "";
+ }
+}
- // apply penalties
- const auto& penalty_tokens = params.use_penalty_prompt_tokens ? params.penalty_prompt_tokens : prev;
- const int penalty_tokens_used_size = std::min((int)penalty_tokens.size(), penalty_last_n);
- if (penalty_tokens_used_size) {
- const float nl_logit = logits[llama_token_nl(llama_get_model(ctx_main))];
+std::vector<gpt_sampler_type> gpt_sampler_types_from_names(const std::vector<std::string> & names, bool allow_alt_names) {
+ std::unordered_map<std::string, gpt_sampler_type> sampler_canonical_name_map {
+ { "top_k", GPT_SAMPLER_TYPE_TOP_K },
+ { "top_p", GPT_SAMPLER_TYPE_TOP_P },
+ { "typ_p", GPT_SAMPLER_TYPE_TYPICAL_P },
+ { "min_p", GPT_SAMPLER_TYPE_MIN_P },
+ { "tfs_z", GPT_SAMPLER_TYPE_TFS_Z },
+ { "temperature", GPT_SAMPLER_TYPE_TEMPERATURE },
+ };
- llama_sample_repetition_penalties(ctx_main, &cur_p,
- penalty_tokens.data() + penalty_tokens.size() - penalty_tokens_used_size,
- penalty_tokens_used_size, penalty_repeat, penalty_freq, penalty_present);
+    // sampler names can be written in multiple ways, so accept both the
+    // canonical names and the alternative spellings used in user input
+ std::unordered_map<std::string, gpt_sampler_type> sampler_alt_name_map {
+ { "top-k", GPT_SAMPLER_TYPE_TOP_K },
+ { "top-p", GPT_SAMPLER_TYPE_TOP_P },
+ { "nucleus", GPT_SAMPLER_TYPE_TOP_P },
+ { "typical-p", GPT_SAMPLER_TYPE_TYPICAL_P },
+ { "typical", GPT_SAMPLER_TYPE_TYPICAL_P },
+ { "typ-p", GPT_SAMPLER_TYPE_TYPICAL_P },
+ { "typ", GPT_SAMPLER_TYPE_TYPICAL_P },
+ { "min-p", GPT_SAMPLER_TYPE_MIN_P },
+ { "tfs-z", GPT_SAMPLER_TYPE_TFS_Z },
+ { "tfs", GPT_SAMPLER_TYPE_TFS_Z },
+ { "temp", GPT_SAMPLER_TYPE_TEMPERATURE },
+ };
+
+ std::vector<gpt_sampler_type> samplers;
+ samplers.reserve(names.size());
- if (!penalize_nl) {
- for (size_t idx = 0; idx < cur_p.size; idx++) {
- if (cur_p.data[idx].id == llama_token_nl(llama_get_model(ctx_main))) {
- cur_p.data[idx].logit = nl_logit;
- break;
+ for (const auto & name : names) {
+ auto sampler = sampler_canonical_name_map.find(name);
+ if (sampler != sampler_canonical_name_map.end()) {
+ samplers.push_back(sampler->second);
+ } else {
+ if (allow_alt_names) {
+ sampler = sampler_alt_name_map.find(name);
+ if (sampler != sampler_alt_name_map.end()) {
+ samplers.push_back(sampler->second);
}
}
}
}
- // apply grammar checks before sampling logic
- if (apply_grammar && ctx_sampling->grammar != NULL) {
- llama_grammar_sample(ctx_sampling->grammar, ctx_main, &cur_p);
- }
-
- return cur_p;
+ return samplers;
}
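+
+// illustrative example (derived from the two maps above): with allow_alt_names == true,
+// canonical and alternative spellings resolve to the same sampler types, e.g.
+//
+//   gpt_sampler_types_from_names({ "top-k", "typical", "temp" }, true)
+//   // -> { GPT_SAMPLER_TYPE_TOP_K, GPT_SAMPLER_TYPE_TYPICAL_P, GPT_SAMPLER_TYPE_TEMPERATURE }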
-llama_token llama_sampling_sample(
- struct llama_sampling_context * ctx_sampling,
- struct llama_context * ctx_main,
- struct llama_context * ctx_cfg,
- const int idx) {
- // Call the implementation function with is_resampling set to false by default
- return llama_sampling_sample_impl(ctx_sampling, ctx_main, ctx_cfg, idx, /* is_resampling= */ false);
-}
-
-llama_token_data_array llama_sampling_prepare(
- struct llama_sampling_context * ctx_sampling,
- struct llama_context * ctx_main,
- struct llama_context * ctx_cfg,
- const int idx,
- bool apply_grammar,
- std::vector<float> * original_logits) {
- return llama_sampling_prepare_impl(ctx_sampling,ctx_main, ctx_cfg, idx, apply_grammar, original_logits);
-}
+std::vector<gpt_sampler_type> gpt_sampler_types_from_chars(const std::string & chars) {
+ std::unordered_map<char, gpt_sampler_type> sampler_name_map {
+ { gpt_sampler_type_to_chr(GPT_SAMPLER_TYPE_TOP_K), GPT_SAMPLER_TYPE_TOP_K },
+ { gpt_sampler_type_to_chr(GPT_SAMPLER_TYPE_TFS_Z), GPT_SAMPLER_TYPE_TFS_Z },
+ { gpt_sampler_type_to_chr(GPT_SAMPLER_TYPE_TYPICAL_P), GPT_SAMPLER_TYPE_TYPICAL_P },
+ { gpt_sampler_type_to_chr(GPT_SAMPLER_TYPE_TOP_P), GPT_SAMPLER_TYPE_TOP_P },
+ { gpt_sampler_type_to_chr(GPT_SAMPLER_TYPE_MIN_P), GPT_SAMPLER_TYPE_MIN_P },
+ { gpt_sampler_type_to_chr(GPT_SAMPLER_TYPE_TEMPERATURE), GPT_SAMPLER_TYPE_TEMPERATURE }
+ };
-void llama_sampling_accept(
- struct llama_sampling_context * ctx_sampling,
- struct llama_context * ctx_main,
- llama_token id,
- bool apply_grammar) {
- ctx_sampling->prev.erase(ctx_sampling->prev.begin());
- ctx_sampling->prev.push_back(id);
+ std::vector<gpt_sampler_type> samplers;
+ samplers.reserve(chars.size());
- if (ctx_sampling->grammar != NULL && apply_grammar) {
- llama_grammar_accept_token(ctx_sampling->grammar, ctx_main, id);
+ for (const auto & c : chars) {
+ const auto sampler = sampler_name_map.find(c);
+ if (sampler != sampler_name_map.end()) {
+ samplers.push_back(sampler->second);
+ }
}
+
+ return samplers;
}
#include "llama.h"
-#include "grammar-parser.h"
-
-#include <random>
#include <string>
-#include <unordered_map>
#include <vector>
-// sampler types
-enum class llama_sampler_type : char {
- TOP_K = 'k',
- TOP_P = 'p',
- MIN_P = 'm',
- TFS_Z = 'f',
- TYPICAL_P = 'y',
- TEMPERATURE = 't'
+enum gpt_sampler_type {
+ GPT_SAMPLER_TYPE_NONE = 0,
+ GPT_SAMPLER_TYPE_TOP_K = 1,
+ GPT_SAMPLER_TYPE_TOP_P = 2,
+ GPT_SAMPLER_TYPE_MIN_P = 3,
+ GPT_SAMPLER_TYPE_TFS_Z = 4,
+ GPT_SAMPLER_TYPE_TYPICAL_P = 5,
+ GPT_SAMPLER_TYPE_TEMPERATURE = 6,
};
// sampling parameters
-typedef struct llama_sampling_params {
- int32_t n_prev = 64; // number of previous tokens to remember
- int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens.
- int32_t min_keep = 0; // 0 = disabled, otherwise samplers should return at least min_keep tokens
- int32_t top_k = 40; // <= 0 to use vocab size
- float top_p = 0.95f; // 1.0 = disabled
- float min_p = 0.05f; // 0.0 = disabled
- float tfs_z = 1.00f; // 1.0 = disabled
- float typical_p = 1.00f; // 1.0 = disabled
- float temp = 0.80f; // <= 0.0 to sample greedily, 0.0 to not output probabilities
- float dynatemp_range = 0.00f; // 0.0 = disabled
- float dynatemp_exponent = 1.00f; // controls how entropy maps to temperature in dynamic temperature sampler
- int32_t penalty_last_n = 64; // last n tokens to penalize (0 = disable penalty, -1 = context size)
- float penalty_repeat = 1.00f; // 1.0 = disabled
- float penalty_freq = 0.00f; // 0.0 = disabled
- float penalty_present = 0.00f; // 0.0 = disabled
- int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
- float mirostat_tau = 5.00f; // target entropy
- float mirostat_eta = 0.10f; // learning rate
- bool penalize_nl = false; // consider newlines as a repeatable token
- uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampling_context
-
- std::vector<llama_sampler_type> samplers_sequence = {
- llama_sampler_type::TOP_K,
- llama_sampler_type::TFS_Z,
- llama_sampler_type::TYPICAL_P,
- llama_sampler_type::TOP_P,
- llama_sampler_type::MIN_P,
- llama_sampler_type::TEMPERATURE
+struct gpt_sampler_params {
+ uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampler
+
+ int32_t n_prev = 64; // number of previous tokens to remember
+ int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens.
+ int32_t min_keep = 0; // 0 = disabled, otherwise samplers should return at least min_keep tokens
+ int32_t top_k = 40; // <= 0 to use vocab size
+ float top_p = 0.95f; // 1.0 = disabled
+ float min_p = 0.05f; // 0.0 = disabled
+ float tfs_z = 1.00f; // 1.0 = disabled
+ float typ_p = 1.00f; // typical_p, 1.0 = disabled
+ float temp = 0.80f; // <= 0.0 to sample greedily, 0.0 to not output probabilities
+ float dynatemp_range = 0.00f; // 0.0 = disabled
+ float dynatemp_exponent = 1.00f; // controls how entropy maps to temperature in dynamic temperature sampler
+ int32_t penalty_last_n = 64; // last n tokens to penalize (0 = disable penalty, -1 = context size)
+ float penalty_repeat = 1.00f; // 1.0 = disabled
+ float penalty_freq = 0.00f; // 0.0 = disabled
+ float penalty_present = 0.00f; // 0.0 = disabled
+ int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
+ float mirostat_tau = 5.00f; // target entropy
+ float mirostat_eta = 0.10f; // learning rate
+ bool penalize_nl = false; // consider newlines as a repeatable token
+ bool ignore_eos = false;
+
+ std::vector<enum gpt_sampler_type> samplers = {
+ GPT_SAMPLER_TYPE_TOP_K,
+ GPT_SAMPLER_TYPE_TFS_Z,
+ GPT_SAMPLER_TYPE_TYPICAL_P,
+ GPT_SAMPLER_TYPE_TOP_P,
+ GPT_SAMPLER_TYPE_MIN_P,
+ GPT_SAMPLER_TYPE_TEMPERATURE
};
- std::string grammar; // optional BNF-like grammar to constrain sampling
-
- // Classifier-Free Guidance
- // https://arxiv.org/abs/2306.17806
- std::string cfg_negative_prompt; // string to help guidance
- float cfg_scale = 1.f; // how strong is guidance
-
- std::unordered_map<llama_token, float> logit_bias; // logit bias for specific tokens
-
- std::vector<llama_token> penalty_prompt_tokens;
- bool use_penalty_prompt_tokens = false;
-} llama_sampling_params;
-
-// general sampler context
-// TODO: move to llama.h
-struct llama_sampling_context {
- // parameters that will be used for sampling
- llama_sampling_params params;
-
- // mirostat sampler state
- float mirostat_mu;
+ std::string grammar; // optional BNF-like grammar to constrain sampling
- llama_grammar * grammar;
+ std::vector<llama_logit_bias> logit_bias; // logit biases to apply
- // internal
- grammar_parser::parse_state parsed_grammar;
+ // print the parameters into a string
+ std::string print() const;
+};
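+
+// illustrative sketch: overriding a few of the defaults above before constructing a sampler
+// (only fields declared in gpt_sampler_params are used):
+//
+//   gpt_sampler_params sparams;
+//   sparams.seed     = 42;
+//   sparams.temp     = 0.2f;
+//   sparams.top_k    = 20;
+//   sparams.samplers = { GPT_SAMPLER_TYPE_TOP_K, GPT_SAMPLER_TYPE_TEMPERATURE };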
- // TODO: replace with ring-buffer
- std::vector<llama_token> prev;
- std::vector<llama_token_data> cur;
- size_t n_valid; // Number of correct top tokens with correct probabilities.
+// gpt_sampler extends llama_sampler with additional functionality:
+//
+// - grammar support
+// - custom sampler logic based on the parameters
+// - history of the last accepted tokens
+// - performance metrics
+//
+// The goal is to have a common implementation of the sampling logic shared across the examples.
+// For example, depending on the temperature, the sampling chain can be very simple (greedy) or more
+// complex (top-k, top-p, etc.).
+//
+// Another example is related to the grammar. In general, applying the grammar constraints to the full
+// vocabulary can be very taxing. To improve performance, the grammar is first applied only to the sampled
+// token in order to verify whether it fits the grammar. Only if the token does not fit the grammar are the
+// constraints applied to the full vocabulary and the token resampled.
+//
+// The gpt_sampler also maintains a container with the last accepted tokens. In the future, this can
+// be moved into the core llama library.
+//
+// For convenience, the gpt_sampler also maintains a container with the current candidate tokens.
+// This can be used to access the probabilities of the rest of the non-sampled tokens.
+//
+// TODO: measure grammar performance
+//
- std::mt19937 rng;
-};
+struct gpt_sampler;
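+
+// a minimal sketch of the lazy grammar application described above (illustrative only, not the
+// actual implementation; sample_from_chain, grammar_accepts and apply_grammar_to_all are
+// hypothetical helpers):
+//
+//   llama_token id = sample_from_chain(candidates);     // fast path: grammar not applied
+//   if (has_grammar && !grammar_accepts(id)) {
+//       apply_grammar_to_all(candidates);               // slow path: constrain the full vocabulary
+//       id = sample_from_chain(candidates);             // resample from the constrained candidates
+//   }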
-#include "common.h"
+// llama_sampler API overloads
-// Create a new sampling context instance.
-struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params);
+struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const struct gpt_sampler_params & params);
-void llama_sampling_free(struct llama_sampling_context * ctx);
+void gpt_sampler_free(struct gpt_sampler * gsmpl);
-// Reset the sampler context
-// - clear prev tokens
-// - reset grammar
-void llama_sampling_reset(llama_sampling_context * ctx);
+// if accept_grammar is true, the token is accepted both by the sampling chain and the grammar
+void gpt_sampler_accept(struct gpt_sampler * gsmpl, llama_token token, bool accept_grammar);
+void gpt_sampler_reset (struct gpt_sampler * gsmpl);
+struct gpt_sampler * gpt_sampler_clone (struct gpt_sampler * gsmpl);
-// Set the sampler seed
-void llama_sampling_set_rng_seed(struct llama_sampling_context * ctx, uint32_t seed);
+// arguments can be nullptr to skip printing
+void gpt_perf_print(const struct llama_context * ctx, const struct gpt_sampler * gsmpl);
-// Copy the sampler context
-void llama_sampling_cp(llama_sampling_context * src, llama_sampling_context * dst);
+// extended sampling implementation:
+//
+// - set logits
+// - apply the configured sampler chain
+// - check if the token fits the grammar (if any)
+// - if not: resample by first applying the grammar constraints and then sampling again (slower path)
+//
+// if grammar_first is true, the grammar is applied before the samplers (slower)
+// useful in cases where all the resulting candidates (not just the sampled one) must fit the grammar
+//
+llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first = false);
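+
+// illustrative usage sketch of the API above, assuming a loaded llama_model * model,
+// a llama_context * ctx and a filled gpt_sampler_params params:
+//
+//   struct gpt_sampler * smpl = gpt_sampler_init(model, params);
+//
+//   while (/* still generating */) {
+//       const llama_token id = gpt_sampler_sample(smpl, ctx, -1);
+//
+//       gpt_sampler_accept(smpl, id, /* accept_grammar= */ true);
+//
+//       // ... decode the token, check for end of generation, etc.
+//   }
+//
+//   gpt_sampler_free(smpl);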
-// Get the last sampled token
-llama_token llama_sampling_last(llama_sampling_context * ctx);
+// helpers
-// Get a string representation of the last sampled tokens
-std::string llama_sampling_prev_str(llama_sampling_context * ctx_sampling, llama_context * ctx_main, int n);
+// access the internal list of current candidate tokens
+llama_token_data_array * gpt_sampler_get_candidates(struct gpt_sampler * gsmpl);
-// Print sampling parameters into a string
-std::string llama_sampling_print(const llama_sampling_params & params);
+// get the last accepted token
+llama_token gpt_sampler_last(const struct gpt_sampler * gsmpl);
-// Print sampling order into a string
-std::string llama_sampling_order_print(const llama_sampling_params & params);
+// print the sampler chain into a string
+std::string gpt_sampler_print(const struct gpt_sampler * gsmpl);
-std::string llama_sampling_type_to_str(llama_sampler_type sampler_type);
+// get a string representation of the last accepted tokens
+std::string gpt_sampler_prev_str(gpt_sampler * gsmpl, llama_context * ctx, int n);
-std::vector<llama_sampler_type> llama_sampling_types_from_names(const std::vector<std::string> & names, bool allow_alt_names);
-std::vector<llama_sampler_type> llama_sampling_types_from_chars(const std::string & names_string);
+char gpt_sampler_type_to_chr(enum gpt_sampler_type cnstr);
+std::string gpt_sampler_type_to_str(enum gpt_sampler_type cnstr);
-// this is a common sampling function used across the examples for convenience
-// it can serve as a starting point for implementing your own sampling function
-// Note: When using multiple sequences, it is the caller's responsibility to call
-// llama_sampling_reset when a sequence ends
-//
-// required:
-// - ctx_main: context to use for sampling
-// - ctx_sampling: sampling-specific context
-//
-// optional:
-// - ctx_cfg: context to use for classifier-free guidance
-// - idx: sample from llama_get_logits_ith(ctx, idx)
-//
-// returns:
-// - token: sampled token
-// - candidates: vector of candidate tokens
-//
-llama_token llama_sampling_sample(
- struct llama_sampling_context * ctx_sampling,
- struct llama_context * ctx_main,
- struct llama_context * ctx_cfg,
- int idx = -1);
-
-// Prepares and adjusts the set of token candidates for sampling based on penalties, biases, and sampling parameters.
-llama_token_data_array llama_sampling_prepare(
- struct llama_sampling_context * ctx_sampling,
- struct llama_context * ctx_main,
- struct llama_context * ctx_cfg,
- int idx = 0,
- bool apply_grammar = true,
- std::vector<float> * original_logits = nullptr);
-
-void llama_sampling_accept(
- struct llama_sampling_context * ctx_sampling,
- struct llama_context * ctx_main,
- llama_token id,
- bool apply_grammar);
+std::vector<enum gpt_sampler_type> gpt_sampler_types_from_names(const std::vector<std::string> & names, bool allow_alt_names);
+std::vector<enum gpt_sampler_type> gpt_sampler_types_from_chars(const std::string & chars);
}
}
- llama_print_timings(ctx);
+ LOG_TEE("\n");
+ llama_perf_print(ctx, LLAMA_PERF_TYPE_CONTEXT);
llama_batch_free(batch);
print("Failed to load model")
exit(1)
}
-
defer {
llama_free_model(model)
}
let n_kv_req = UInt32(tokens.count) + UInt32((n_len - Int(tokens.count)) * n_parallel)
var context_params = llama_context_default_params()
-context_params.seed = 1234
context_params.n_ctx = n_kv_req
context_params.n_batch = UInt32(max(n_len, n_parallel))
context_params.n_threads = 8
print("Failed to initialize context")
exit(1)
}
-
defer {
llama_free(context)
}
+var sparams = llama_sampler_chain_default_params()
+
+let smpl = llama_sampler_chain_init(sparams)
+guard smpl != nil else {
+ print("Failed to initialize sampling")
+ exit(1)
+}
+defer {
+ llama_sampler_free(smpl)
+}
+
+llama_sampler_chain_add(smpl, llama_sampler_init_top_k(40))
+llama_sampler_chain_add(smpl, llama_sampler_init_top_p(0.9, 1))
+llama_sampler_chain_add(smpl, llama_sampler_init_temp (0.4))
+llama_sampler_chain_add(smpl, llama_sampler_init_dist (1234))
+
let n_ctx = llama_n_ctx(context)
print("\nn_len = \(n_len), n_ctx = \(n_ctx), n_batch = \(context_params.n_batch), n_parallel = \(n_parallel), n_kv_req = \(n_kv_req)\n")
continue
}
- var n_vocab = llama_n_vocab(model)
- var logits = llama_get_logits_ith(context, i_batch[i])
-
- var candidates: [llama_token_data] = .init(repeating: llama_token_data(), count: Int(n_vocab))
-
- for token_id in 0 ..< n_vocab {
- candidates.append(llama_token_data(id: token_id, logit: logits![Int(token_id)], p: 0.0))
- }
-
- var candidates_p: llama_token_data_array = .init(
- data: &candidates,
- size: candidates.count,
- sorted: false
- )
-
- let top_k: Int32 = 40
- let top_p: Float = 0.9
- let temp: Float = 0.4
-
- llama_sample_top_k(context, &candidates_p, top_k, 1)
- llama_sample_top_p(context, &candidates_p, top_p, 1)
- llama_sample_temp(context, &candidates_p, temp)
-
- let new_token_id = llama_sample_token(context, &candidates_p)
+ let new_token_id = llama_sampler_sample(smpl, context, i_batch[i])
- // const llama_token new_token_id = llama_sample_token_greedy(ctx, &candidates_p);
+ llama_sampler_accept(smpl, new_token_id)
// is it an end of stream? -> mark the stream as finished
if llama_token_is_eog(model, new_token_id) || n_cur == n_len {
let t_main_end = ggml_time_us()
-print("decoded \(n_decode) tokens in \(String(format: "%.2f", Double(t_main_end - t_main_start) / 1_000_000.0)) s, speed: \(String(format: "%.2f", Double(n_decode) / (Double(t_main_end - t_main_start) / 1_000_000.0))) t/s\n")
+print("decoded \(n_decode) tokens in \(String(format: "%.2f", Double(t_main_end - t_main_start) / 1_000_000.0)) s, speed: \(String(format: "%.2f", Double(n_decode) / (Double(t_main_end - t_main_start) / 1_000_000.0))) t/s\n\n")
-llama_print_timings(context)
+llama_perf_print(UnsafeRawPointer(context), LLAMA_PERF_TYPE_CONTEXT)
+llama_perf_print(UnsafeRawPointer(smpl), LLAMA_PERF_TYPE_SAMPLER_CHAIN)
private func tokenize(text: String, add_bos: Bool) -> [llama_token] {
let utf8Count = text.utf8.count
#include "llama.h"
#include <algorithm>
-#include <cmath>
#include <cstdio>
#include <string>
#include <vector>
llama_context * ctx = llama_new_context_with_model(model, ctx_params);
+ auto sparams = llama_sampler_chain_default_params();
+
+ llama_sampler * smpl = llama_sampler_chain_init(sparams);
+
+ llama_sampler_chain_add(smpl, llama_sampler_init_top_k(params.sparams.top_k));
+ llama_sampler_chain_add(smpl, llama_sampler_init_top_p(params.sparams.top_p, params.sparams.min_keep));
+ llama_sampler_chain_add(smpl, llama_sampler_init_temp (params.sparams.temp));
+ llama_sampler_chain_add(smpl, llama_sampler_init_dist (params.sparams.seed));
+
if (ctx == NULL) {
fprintf(stderr , "%s: error: failed to create the llama_context\n" , __func__);
return 1;
continue;
}
- auto n_vocab = llama_n_vocab(model);
- auto * logits = llama_get_logits_ith(ctx, i_batch[i]);
-
- std::vector<llama_token_data> candidates;
- candidates.reserve(n_vocab);
-
- for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
- candidates.emplace_back(llama_token_data{ token_id, logits[token_id], 0.0f });
- }
+ const llama_token new_token_id = llama_sampler_sample(smpl, ctx, i_batch[i]);
- llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
-
- const int top_k = 40;
- const float top_p = 0.9f;
- const float temp = 0.4f;
-
- llama_sample_top_k(ctx, &candidates_p, top_k, 1);
- llama_sample_top_p(ctx, &candidates_p, top_p, 1);
- llama_sample_temp (ctx, &candidates_p, temp);
-
- const llama_token new_token_id = llama_sample_token(ctx, &candidates_p);
-
- //const llama_token new_token_id = llama_sample_token_greedy(ctx, &candidates_p);
+ llama_sampler_accept(smpl, new_token_id);
// is it an end of generation? -> mark the stream as finished
if (llama_token_is_eog(model, new_token_id) || n_cur == n_predict) {
LOG_TEE("%s: decoded %d tokens in %.2f s, speed: %.2f t/s\n",
__func__, n_decode, (t_main_end - t_main_start) / 1000000.0f, n_decode / ((t_main_end - t_main_start) / 1000000.0f));
- llama_print_timings(ctx);
+ LOG_TEE("\n");
+ llama_perf_print(smpl, LLAMA_PERF_TYPE_SAMPLER_CHAIN);
+ llama_perf_print(ctx, LLAMA_PERF_TYPE_CONTEXT);
fprintf(stderr, "\n");
llama_batch_free(batch);
+ llama_sampler_free(smpl);
llama_free(ctx);
llama_free_model(model);
print_build_info();
- if (params.seed == LLAMA_DEFAULT_SEED) {
- params.seed = time(NULL);
- }
-
- fprintf(stderr, "%s: seed = %u\n", __func__, params.seed);
-
- std::mt19937 rng(params.seed);
+ LOG_TEE("%s: seed = %u\n", __func__, params.sparams.seed);
llama_backend_init();
llama_numa_init(params.numa);
if (notArray) fprintf(stdout, "\n}\n");
}
+ LOG_TEE("\n");
+ llama_perf_print(ctx, LLAMA_PERF_TYPE_CONTEXT);
+
// clean up
- llama_print_timings(ctx);
llama_batch_free(batch);
llama_free(ctx);
llama_free_model(model);
print_build_info();
- std::mt19937 rng(params.seed);
-
llama_backend_init();
llama_numa_init(params.numa);
return 1;
}
- llama_print_timings(ctx);
+ LOG_TEE("\n");
+ llama_perf_print(ctx, LLAMA_PERF_TYPE_CONTEXT);
llama_free(ctx);
llama_free_model(model);
-#define LLAMA_API_INTERNAL
-
-#include "grammar-parser.h"
-#include "ggml.h"
-#include "llama.h"
#include "unicode.h"
+#include "llama-grammar.h"
#include <cstdio>
#include <cstdlib>
#include <string>
#include <vector>
-static bool llama_sample_grammar_string(struct llama_grammar * grammar, const std::string & input_str, size_t & error_pos, std::string & error_msg) {
- auto decoded = decode_utf8(input_str, {});
- const auto & code_points = decoded.first;
+static bool llama_grammar_validate(struct llama_grammar * grammar, const std::string & input_str, size_t & error_pos, std::string & error_msg) {
+ const auto cpts = unicode_cpts_from_utf8(input_str);
const llama_grammar_rules & rules = llama_grammar_get_rules (grammar);
- llama_grammar_stacks & cur_stacks = llama_grammar_get_stacks(grammar);
+ llama_grammar_stacks & stacks_cur = llama_grammar_get_stacks(grammar);
size_t pos = 0;
- for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
- const llama_grammar_stacks prev_stacks = llama_grammar_get_stacks(grammar); // copy
+ for (const auto & cpt : cpts) {
+ const llama_grammar_stacks stacks_prev = llama_grammar_get_stacks(grammar); // copy
- llama_grammar_accept(rules, prev_stacks, *it, cur_stacks);
+ llama_grammar_accept(rules, stacks_prev, cpt, stacks_cur);
- if (cur_stacks.empty()) {
+ if (stacks_cur.empty()) {
error_pos = pos;
- error_msg = "Unexpected character '" + unicode_cpt_to_utf8(*it) + "'";
- cur_stacks = prev_stacks;
+ error_msg = "Unexpected character '" + unicode_cpt_to_utf8(cpt) + "'";
+ stacks_cur = stacks_prev;
return false;
}
++pos;
}
- for (const auto & stack : cur_stacks) {
+ for (const auto & stack : stacks_cur) {
if (stack.empty()) {
return true;
}
grammar_str = buffer.str();
}
- // Parse the GBNF grammar
- auto parsed_grammar = grammar_parser::parse(grammar_str.c_str());
-
- // will be empty (default) if there are parse errors
- if (parsed_grammar.rules.empty()) {
- fprintf(stdout, "%s: failed to parse grammar\n", __func__);
- return 1;
- }
-
- // Ensure that there is a "root" node.
- if (parsed_grammar.symbol_ids.find("root") == parsed_grammar.symbol_ids.end()) {
- fprintf(stdout, "%s: grammar does not contain a 'root' symbol\n", __func__);
- return 1;
- }
-
- std::vector<const llama_grammar_element *> grammar_rules(parsed_grammar.c_rules());
-
- // Create the LLAMA grammar
- auto grammar = llama_grammar_init(
- grammar_rules.data(),
- grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
+ llama_grammar * grammar = llama_grammar_init_impl(nullptr, grammar_str.c_str(), "root");
if (grammar == nullptr) {
throw std::runtime_error("Failed to initialize llama_grammar");
}
// Validate the input string against the grammar
size_t error_pos;
std::string error_msg;
- bool is_valid = llama_sample_grammar_string(grammar, input_str, error_pos, error_msg);
+ bool is_valid = llama_grammar_validate(grammar, input_str, error_pos, error_msg);
if (is_valid) {
fprintf(stdout, "Input string is valid according to the grammar.\n");
}
// Clean up
- llama_grammar_free(grammar);
+ llama_grammar_free_impl(grammar);
return 0;
}
static std::vector<std::vector<float>> encode(llama_context * ctx, const std::vector<std::string> & sentences, const std::string & instruction) {
std::vector<std::vector<float>> result;
- const llama_model * mdl = llama_get_model(ctx);
+ const llama_model * model = llama_get_model(ctx);
llama_batch batch = llama_batch_init(llama_n_batch(ctx), 0, 1);
const std::string input_string = instruction + sentences[i];
- std::vector<llama_token> inputs = llama_tokenize(mdl, input_string, true, false);
+ std::vector<llama_token> inputs = llama_tokenize(model, input_string, true, false);
const int32_t n_toks = inputs.size();
// GritLM seems to have EOS = ""
// https://github.com/ContextualAI/gritlm/blob/92025b16534712b31b3c4aaaf069350e222bd5f8/gritlm/gritlm.py#L18
- // inputs.push_back(llama_token_eos(mdl));
+ // inputs.push_back(llama_token_eos(model));
// we want to ignore instruction tokens for mean pooling
- const int32_t n_inst = llama_tokenize(mdl, instruction, true, false).size();
+ const int32_t n_inst = llama_tokenize(model, instruction, true, false).size();
#ifdef GRIT_DEBUG
// debug tokens - should be matching as referenced in the GritLM sample
llama_decode(ctx, batch);
// get embedding dimensions
- uint64_t n_embd = llama_n_embd(mdl);
+ uint64_t n_embd = llama_n_embd(model);
// allocate embedding output
std::vector<float> emb_unorm(n_embd, 0.0f);
return result;
}
-static std::string generate(llama_context * ctx, const std::string & prompt, bool stream) {
+static std::string generate(llama_context * ctx, llama_sampler * smpl, const std::string & prompt, bool stream) {
std::string result;
- const llama_model * mdl = llama_get_model(ctx);
- llama_token eos_token = llama_token_eos(mdl);
+ const llama_model * model = llama_get_model(ctx);
+ llama_token eos_token = llama_token_eos(model);
llama_kv_cache_clear(ctx);
llama_set_embeddings(ctx, false);
llama_batch bat = llama_batch_init(llama_n_batch(ctx), 0, 1);
- std::vector<llama_token> inputs = llama_tokenize(mdl, prompt, false, true);
+ std::vector<llama_token> inputs = llama_tokenize(model, prompt, false, true);
int32_t i_current_token = 0;
while (true) {
llama_batch_clear(bat);
- auto n_inputs = (int32_t)inputs.size();
- for (int32_t i = 0; i < n_inputs; i++) {
- llama_batch_add(bat, inputs[i], i_current_token++, { 0 }, i == n_inputs - 1);
+ {
+ const int32_t n_inputs = inputs.size();
+
+ for (int32_t i = 0; i < n_inputs; i++) {
+ llama_batch_add(bat, inputs[i], i_current_token++, { 0 }, i == n_inputs - 1);
+ }
}
inputs.clear();
llama_decode(ctx, bat);
- auto logits = llama_get_logits_ith(ctx, bat.n_tokens - 1);
- auto candidates = std::vector<llama_token_data>(llama_n_vocab(mdl));
- auto n_candidates = (int32_t)candidates.size();
- for (int32_t token = 0; token < n_candidates; token++) {
- candidates[token] = llama_token_data{ token, logits[token], 0.0f };
- }
- auto candidates_p = llama_token_data_array{ candidates.data(), candidates.size(), false };
+ llama_token token = llama_sampler_sample(smpl, ctx, bat.n_tokens - 1);
+ llama_sampler_accept(smpl, token);
- llama_token token = llama_sample_token_greedy(ctx, &candidates_p);
if (token == eos_token) {
break;
}
llama_backend_init();
- llama_model * mdl = llama_load_model_from_file(params.model.c_str(), mparams);
+ llama_model * model = llama_load_model_from_file(params.model.c_str(), mparams);
// create generation context
- llama_context * ctx = llama_new_context_with_model(mdl, cparams);
+ llama_context * ctx = llama_new_context_with_model(model, cparams);
+
+ auto sparams = llama_sampler_chain_default_params();
+
+ sparams.no_perf = false;
+
+ llama_sampler * smpl = llama_sampler_chain_init(sparams);
+
+ llama_sampler_chain_add(smpl, llama_sampler_init_greedy());
// ### Embedding/Representation ###
// samples taken from: https://github.com/ContextualAI/gritlm#basic
const std::vector<std::vector<float>> d_rep = encode(ctx, documents, gritlm_instruction(""));
const std::vector<std::vector<float>> q_rep = encode(ctx, queries, gritlm_instruction(instruction));
- const int n_embd = llama_n_embd(mdl);
+ const int n_embd = llama_n_embd(model);
const float cosine_sim_q0_d0 = llama_embd_similarity_cos(q_rep[0].data(), d_rep[0].data(), n_embd);
const float cosine_sim_q0_d1 = llama_embd_similarity_cos(q_rep[0].data(), d_rep[1].data(), n_embd);
// GritLM models are not finetuned with system prompts, as you can just include system-like instructions together with your user instruction
{
const std::string prompt = "<|user|>\nPlease write me a poem about my recent hike of Mt. Fuji at midnight in the style of Shakespeare.\n<|assistant|>\n";
- std::string response = generate(ctx, prompt, true);
+ std::string response = generate(ctx, smpl, prompt, true);
}
+ llama_sampler_free(smpl);
llama_free(ctx);
- llama_free_model(mdl);
+ llama_free_model(model);
llama_backend_free();
return 0;
g_collector.save_imatrix();
- llama_print_timings(ctx);
+ LOG_TEE("\n");
+ llama_perf_print(ctx, LLAMA_PERF_TYPE_CONTEXT);
llama_free(ctx);
llama_free_model(model);
#include "console.h"
#include "llama.h"
-#include "grammar-parser.h"
#include <cassert>
#include <cinttypes>
static llama_context ** g_ctx;
static llama_model ** g_model;
+static gpt_sampler ** g_smpl;
static gpt_params * g_params;
static std::vector<llama_token> * g_input_tokens;
static std::ostringstream * g_output_ss;
yaml_dump_string_multiline(logfile, "output", output.c_str());
yaml_dump_vector_int(logfile, "output_tokens", output_tokens);
- llama_dump_timing_info_yaml(logfile, ctx);
+ llama_perf_dump_yaml(logfile, ctx);
fclose(logfile);
}
} else {
console::cleanup();
printf("\n");
- llama_print_timings(*g_ctx);
+ gpt_perf_print(*g_ctx, *g_smpl);
write_logfile(*g_ctx, *g_params, *g_model, *g_input_tokens, g_output_ss->str(), *g_output_tokens);
_exit(130);
}
int main(int argc, char ** argv) {
gpt_params params;
- llama_sampling_params & sparams = params.sparams;
g_params = ¶ms;
if (!gpt_params_parse(argc, argv, params)) {
return 1;
}
+ auto & sparams = params.sparams;
+
#ifndef LOG_DISABLE_LOGS
log_set_target(log_filename_generator("infill", "log"));
LOG_TEE("Log start\n");
LOG_TEE("%s: warning: scaling RoPE frequency by %g.\n", __func__, params.rope_freq_scale);
}
- LOG_TEE("%s: build = %d (%s)\n", __func__, LLAMA_BUILD_NUMBER, LLAMA_COMMIT);
- LOG_TEE("%s: built with %s for %s\n", __func__, LLAMA_COMPILER, LLAMA_BUILD_TARGET);
-
- if (params.seed == LLAMA_DEFAULT_SEED) {
- params.seed = time(NULL);
- }
-
- LOG_TEE("%s: seed = %u\n", __func__, params.seed);
+ print_build_info();
- std::mt19937 rng(params.seed);
+ LOG_TEE("%s: seed = %u\n", __func__, params.sparams.seed);
LOG("%s: llama backend init\n", __func__);
llama_backend_init();
llama_numa_init(params.numa);
- llama_model * model;
- llama_context * ctx;
+ llama_model * model = nullptr;
+ llama_context * ctx = nullptr;
+ gpt_sampler * smpl = nullptr;
g_model = &model;
g_ctx = &ctx;
+ g_smpl = &smpl;
// load the model and apply lora adapter, if any
LOG("%s: load the model and apply lora adapter, if any\n", __func__);
LOG_TEE("Input suffix: '%s'\n", params.input_suffix.c_str());
}
}
- LOG_TEE("sampling: \n%s\n", llama_sampling_print(sparams).c_str());
+ LOG_TEE("sampling: \n%s\n", sparams.print().c_str());
LOG_TEE("generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep);
LOG_TEE("\n\n");
std::vector<llama_token> embd;
- struct llama_sampling_context * ctx_sampling = llama_sampling_init(sparams);
+ smpl = gpt_sampler_init(model, sparams);
while (n_remain != 0 || params.interactive) {
// predict
embd.clear();
if ((int) embd_inp.size() <= n_consumed && !is_interacting) {
- const llama_token id = llama_sampling_sample(ctx_sampling, ctx, nullptr);
+ const llama_token id = gpt_sampler_sample(smpl, ctx, -1);
- llama_sampling_accept(ctx_sampling, ctx, id, true);
+ gpt_sampler_accept(smpl, id, true);
- LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, ctx_sampling->prev).c_str());
+ // LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, smpl->prev.to_vector()).c_str());
embd.push_back(id);
// push the prompt in the sampling context in order to apply repetition penalties later
// for the prompt, we don't apply grammar rules
- llama_sampling_accept(ctx_sampling, ctx, embd_inp[n_consumed], false);
+ gpt_sampler_accept(smpl, embd_inp[n_consumed], false);
++n_consumed;
if ((int) embd.size() >= params.n_batch) {
// if not currently processing queued inputs;
if ((int) embd_inp.size() <= n_consumed) {
// deal with eot token in infill mode
- if ((llama_sampling_last(ctx_sampling) == llama_token_eot(model) || is_interacting) && params.interactive){
+ if ((gpt_sampler_last(smpl) == llama_token_eot(model) || is_interacting) && params.interactive){
if (is_interacting && !params.interactive_first) {
// print an eot token
printf("%s", llama_token_to_piece(ctx, llama_token_eot(model)).c_str());
is_interacting = false;
}
// deal with end of generation tokens in interactive mode
- else if (llama_token_is_eog(model, llama_sampling_last(ctx_sampling))) {
+ else if (llama_token_is_eog(model, gpt_sampler_last(smpl))) {
LOG("found EOS token\n");
if (params.interactive) {
if (n_past > 0) {
if (is_interacting) {
- llama_sampling_reset(ctx_sampling);
+ gpt_sampler_reset(smpl);
}
is_interacting = false;
}
fflush(stdout);
}
- llama_print_timings(ctx);
+ LOG_TEE("\n");
+ gpt_perf_print(ctx, smpl);
write_logfile(ctx, params, model, input_tokens, output_ss.str(), output_tokens);
llama_free(ctx);
llama_free_model(model);
- llama_sampling_free(ctx_sampling);
+ gpt_sampler_free(smpl);
llama_backend_free();
#ifndef LOG_DISABLE_LOGS
fflush(p_err->fout);
}
- llama_print_timings(ctx);
+ llama_perf_print(ctx, LLAMA_PERF_TYPE_CONTEXT);
llama_free(ctx);
LOGi("Using %d threads", n_threads);
llama_context_params ctx_params = llama_context_default_params();
- ctx_params.seed = 1234;
- ctx_params.n_ctx = 2048;
+
+ ctx_params.n_ctx = 2048;
ctx_params.n_threads = n_threads;
ctx_params.n_threads_batch = n_threads;
JNIEnv * env,
jobject,
jlong context_pointer,
+ jlong sampling_pointer,
jlong batch_pointer,
jint n_len,
jobject intvar_ncur
) {
const auto context = reinterpret_cast<llama_context *>(context_pointer);
+ const auto sampling = reinterpret_cast<llama_sampler *>(sampling_pointer);
const auto batch = reinterpret_cast<llama_batch *>(batch_pointer);
const auto model = llama_get_model(context);
if (!la_int_var_value) la_int_var_value = env->GetMethodID(la_int_var, "getValue", "()I");
if (!la_int_var_inc) la_int_var_inc = env->GetMethodID(la_int_var, "inc", "()V");
- auto n_vocab = llama_n_vocab(model);
- auto logits = llama_get_logits_ith(context, batch->n_tokens - 1);
-
- std::vector<llama_token_data> candidates;
- candidates.reserve(n_vocab);
-
- for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
- candidates.emplace_back(llama_token_data{ token_id, logits[token_id], 0.0f });
- }
-
- llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
-
// sample the most likely token
- const auto new_token_id = llama_sample_token_greedy(context, &candidates_p);
+ const auto new_token_id = llama_sampler_sample(sampling, context, batch->n_tokens - 1);
+
+ llama_sampler_accept(sampling, new_token_id);
const auto n_cur = env->CallIntMethod(intvar_ncur, la_int_var_value);
if (llama_token_is_eog(model, new_token_id) || n_cur == n_len) {
actor LlamaContext {
private var model: OpaquePointer
private var context: OpaquePointer
+ private var sampling: UnsafeMutablePointer<llama_sampler>
private var batch: llama_batch
private var tokens_list: [llama_token]
var is_done: Bool = false
self.tokens_list = []
self.batch = llama_batch_init(512, 0, 1)
self.temporary_invalid_cchars = []
+ let sparams = llama_sampler_chain_default_params()
+ self.sampling = llama_sampler_chain_init(sparams)
+ llama_sampler_chain_add(self.sampling, llama_sampler_init_temp(0.4))
+ llama_sampler_chain_add(self.sampling, llama_sampler_init_softmax())
+ llama_sampler_chain_add(self.sampling, llama_sampler_init_dist(1234))
}
deinit {
+ llama_sampler_free(sampling)
llama_batch_free(batch)
llama_free(context)
llama_free_model(model)
print("Using \(n_threads) threads")
var ctx_params = llama_context_default_params()
- ctx_params.seed = 1234
ctx_params.n_ctx = 2048
ctx_params.n_threads = Int32(n_threads)
ctx_params.n_threads_batch = Int32(n_threads)
func completion_loop() -> String {
var new_token_id: llama_token = 0
- let n_vocab = llama_n_vocab(model)
- let logits = llama_get_logits_ith(context, batch.n_tokens - 1)
+ new_token_id = llama_sampler_sample(sampling, context, batch.n_tokens - 1)
- var candidates = Array<llama_token_data>()
- candidates.reserveCapacity(Int(n_vocab))
-
- for token_id in 0..<n_vocab {
- candidates.append(llama_token_data(id: token_id, logit: logits![Int(token_id)], p: 0.0))
- }
- candidates.withUnsafeMutableBufferPointer() { buffer in
- var candidates_p = llama_token_data_array(data: buffer.baseAddress, size: buffer.count, sorted: false)
-
- new_token_id = llama_sample_token_greedy(context, &candidates_p)
- }
+ llama_sampler_accept(sampling, new_token_id)
if llama_token_is_eog(model, new_token_id) || n_cur == n_len {
print("\n")
return true;
}
-static const char * sample(struct llama_sampling_context * ctx_sampling,
+static const char * sample(struct gpt_sampler * smpl,
struct llama_context * ctx_llama,
int * n_past) {
- const llama_token id = llama_sampling_sample(ctx_sampling, ctx_llama, NULL);
- llama_sampling_accept(ctx_sampling, ctx_llama, id, true);
+ const llama_token id = gpt_sampler_sample(smpl, ctx_llama, -1);
+ gpt_sampler_accept(smpl, id, true);
static std::string ret;
if (llama_token_is_eog(llama_get_model(ctx_llama), id)) {
ret = "</s>";
LOG_TEE("\n");
- struct llama_sampling_context * ctx_sampling = llama_sampling_init(params->sparams);
- if (!ctx_sampling) {
+ struct gpt_sampler * smpl = gpt_sampler_init(ctx_llava->model, params->sparams);
+ if (!smpl) {
fprintf(stderr, "%s: failed to initialize sampling subsystem\n", __func__);
exit(1);
}
std::string response = "";
for (int i = 0; i < max_tgt_len; i++) {
- const char * tmp = sample(ctx_sampling, ctx_llava->ctx_llama, &n_past);
+ const char * tmp = sample(smpl, ctx_llava->ctx_llama, &n_past);
response += tmp;
if (strcmp(tmp, "</s>") == 0) break;
if (strstr(tmp, "###")) break; // Yi-VL behavior
fflush(stdout);
}
- llama_sampling_free(ctx_sampling);
+ gpt_sampler_free(smpl);
printf("\n");
}
// process the prompt
process_prompt(ctx_llava, image_embed, ¶ms, params.prompt);
- llama_print_timings(ctx_llava->ctx_llama);
+ llama_perf_print(ctx_llava->ctx_llama, LLAMA_PERF_TYPE_CONTEXT);
llava_image_embed_free(image_embed);
ctx_llava->model = NULL;
llava_free(ctx_llava);
// process the prompt
process_prompt(ctx_llava, image_embed, ¶ms, params.prompt);
- llama_print_timings(ctx_llava->ctx_llama);
+ llama_perf_print(ctx_llava->ctx_llama, LLAMA_PERF_TYPE_CONTEXT);
llava_image_embed_free(image_embed);
ctx_llava->model = NULL;
llava_free(ctx_llava);
LOG_TEE("%s: image token past: %d\n", __func__, n_past);
}
-static const char * sample(struct llama_sampling_context * ctx_sampling,
+static const char * sample(struct gpt_sampler * smpl,
struct llama_context * ctx_llama,
int * n_past) {
- const llama_token id = llama_sampling_sample(ctx_sampling, ctx_llama, NULL);
- llama_sampling_accept(ctx_sampling, ctx_llama, id, true);
+ const llama_token id = gpt_sampler_sample(smpl, ctx_llama, -1);
+ gpt_sampler_accept(smpl, id, true);
static std::string ret;
if (llama_token_is_eog(llama_get_model(ctx_llama), id)) {
ret = "</s>";
return ctx_llava;
}
-static struct llama_sampling_context * llama_init(struct llava_context * ctx_llava, gpt_params * params, std::string prompt, int &n_past, bool is_first = false){
+static struct gpt_sampler * llama_init(struct llava_context * ctx_llava, gpt_params * params, std::string prompt, int &n_past, bool is_first = false){
std::string user_prompt = prompt;
int has_minicpmv_projector = clip_is_minicpmv(ctx_llava->ctx_clip);
if (!is_first) {
LOG_TEE("\n");
- struct llama_sampling_context * ctx_sampling = llama_sampling_init(params->sparams);
- return ctx_sampling;
+ struct gpt_sampler * smpl = gpt_sampler_init(ctx_llava->model, params->sparams);
+ return smpl;
}
-static const char * llama_loop(struct llava_context * ctx_llava,struct llama_sampling_context * ctx_sampling, int &n_past){
+static const char * llama_loop(struct llava_context * ctx_llava, struct gpt_sampler * smpl, int &n_past) {
- const char * tmp = sample(ctx_sampling, ctx_llava->ctx_llama, &n_past);
+ const char * tmp = sample(smpl, ctx_llava->ctx_llama, &n_past);
return tmp;
}
if (!params.prompt.empty()) {
LOG_TEE("<user>%s\n", params.prompt.c_str());
LOG_TEE("<assistant>");
- auto ctx_sampling = llama_init(ctx_llava, ¶ms, params.prompt.c_str(), n_past, true);
+ auto smpl = llama_init(ctx_llava, ¶ms, params.prompt.c_str(), n_past, true);
const int max_tgt_len = params.n_predict < 0 ? 256 : params.n_predict;
std::string response = "";
bool have_tmp = false;
for (int i = 0; i < max_tgt_len; i++) {
- auto tmp = llama_loop(ctx_llava, ctx_sampling, n_past);
+ auto tmp = llama_loop(ctx_llava, smpl, n_past);
response += tmp;
if (strcmp(tmp, "</s>") == 0){
if(!have_tmp)continue;
fflush(stdout);
}
- llama_sampling_free(ctx_sampling);
+ gpt_sampler_free(smpl);
}else {
while (true) {
LOG_TEE("<user>");
std::string prompt;
std::getline(std::cin, prompt);
LOG_TEE("<assistant>");
- auto ctx_sampling = llama_init(ctx_llava, ¶ms, prompt, n_past, true);
+ auto smpl = llama_init(ctx_llava, ¶ms, prompt, n_past, true);
const int max_tgt_len = params.n_predict < 0 ? 256 : params.n_predict;
std::string response = "";
for (int i = 0; i < max_tgt_len; i++) {
- auto tmp = llama_loop(ctx_llava, ctx_sampling, n_past);
+ auto tmp = llama_loop(ctx_llava, smpl, n_past);
response += tmp;
if (strcmp(tmp, "</s>") == 0) break;
if (strstr(tmp, "###")) break; // Yi-VL behavior
if (strstr(response.c_str(), "<user>")) break; // minicpm-v
fflush(stdout);
}
- llama_sampling_free(ctx_sampling);
+ gpt_sampler_free(smpl);
}
}
printf("\n");
- llama_print_timings(ctx_llava->ctx_llama);
+ llama_perf_print(ctx_llava->ctx_llama, LLAMA_PERF_TYPE_CONTEXT);
ctx_llava->model = NULL;
llava_free(ctx_llava);
#include "common.h"
#include "llama.h"
-#include <cmath>
#include <cstdio>
#include <string>
#include <vector>
llama_batch batch = llama_batch_init(params.n_ctx, 0, W + G + 1);
// target model sampling context
- struct llama_sampling_context * ctx_sampling = llama_sampling_init(params.sparams);
+ struct gpt_sampler * smpl = gpt_sampler_init(model, params.sparams);
// verification n-grams
std::vector<ngram_data> ngrams_cur(G);
// sample first token
{
- id = llama_sampling_sample(ctx_sampling, ctx, NULL, 0);
+ id = gpt_sampler_sample(smpl, ctx, 0);
- llama_sampling_accept(ctx_sampling, ctx, id, true);
+ gpt_sampler_accept(smpl, id, true);
{
const std::string token_str = llama_token_to_piece(ctx, id);
}
// sample the next token
- id = llama_sampling_sample(ctx_sampling, ctx, NULL, i_batch);
+ id = gpt_sampler_sample(smpl, ctx, i_batch);
- llama_sampling_accept(ctx_sampling, ctx, id, true);
+ gpt_sampler_accept(smpl, id, true);
// print
{
if (v == 0) {
// sample from the last level
for (int i = 0; i < W; i++) {
- tokens_j[N - 2][i] = llama_sampling_sample(ctx_sampling, ctx, NULL, ngrams_cur.size()*(N-1) + W*(N - 2) + i);
+ tokens_j[N - 2][i] = gpt_sampler_sample(smpl, ctx, ngrams_cur.size()*(N-1) + W*(N - 2) + i);
}
} else {
for (int i = 0; i < W; i++) {
LOG_TEE("n_predict = %d\n", n_predict);
LOG_TEE("n_accept = %d\n", n_accept);
- llama_print_timings(ctx);
+ LOG_TEE("\n");
+ gpt_perf_print(ctx, smpl);
+
+ gpt_sampler_free(smpl);
llama_kv_cache_view_free(&kvc_view);
- llama_sampling_free(ctx_sampling);
llama_batch_free(batch);
#include "common.h"
#include "ngram-cache.h"
-#include <cmath>
#include <cstdint>
#include <cstdio>
#include <fstream>
#include <string>
#include <vector>
-#include <unordered_map>
int main(int argc, char ** argv){
gpt_params params;
bool has_eos = false;
- struct llama_sampling_context * ctx_sampling = llama_sampling_init(params.sparams);
+ struct gpt_sampler * smpl = gpt_sampler_init(model, params.sparams);
std::vector<llama_token> draft;
int i_dft = 0;
while (true) {
// sample from the target model
- llama_token id = llama_sampling_sample(ctx_sampling, ctx, NULL, i_dft);
+ llama_token id = gpt_sampler_sample(smpl, ctx, i_dft);
- llama_sampling_accept(ctx_sampling, ctx, id, true);
+ gpt_sampler_accept(smpl, id, true);
const std::string token_str = llama_token_to_piece(ctx, id);
LOG_TEE("n_accept = %d\n", n_accept);
LOG_TEE("accept = %.3f%%\n", 100.0f * n_accept / n_drafted);
- LOG_TEE("\ntarget:\n");
- llama_print_timings(ctx);
+ LOG_TEE("\ntarget:\n\n");
+ llama_perf_print(smpl, LLAMA_PERF_TYPE_SAMPLER_CHAIN);
+ llama_perf_print(ctx, LLAMA_PERF_TYPE_CONTEXT);
+
+ gpt_sampler_free(smpl);
- llama_sampling_free(ctx_sampling);
llama_batch_free(batch_tgt);
llama_free(ctx);
static llama_context ** g_ctx;
static llama_model ** g_model;
+static gpt_sampler ** g_smpl;
static gpt_params * g_params;
static std::vector<llama_token> * g_input_tokens;
static std::ostringstream * g_output_ss;
yaml_dump_string_multiline(logfile, "output", output.c_str());
yaml_dump_vector_int(logfile, "output_tokens", output_tokens);
- llama_dump_timing_info_yaml(logfile, ctx);
+ llama_perf_dump_yaml(logfile, ctx);
fclose(logfile);
}
} else {
console::cleanup();
printf("\n");
- llama_print_timings(*g_ctx);
+ gpt_perf_print(*g_ctx, *g_smpl);
write_logfile(*g_ctx, *g_params, *g_model, *g_input_tokens, g_output_ss->str(), *g_output_tokens);
_exit(130);
}
static std::string chat_add_and_format(struct llama_model * model, std::vector<llama_chat_msg> & chat_msgs, std::string role, std::string content) {
llama_chat_msg new_msg{role, content};
- auto formatted = llama_chat_format_single(
- model, g_params->chat_template, chat_msgs, new_msg, role == "user");
+ auto formatted = llama_chat_format_single(model, g_params->chat_template, chat_msgs, new_msg, role == "user");
chat_msgs.push_back({role, content});
LOG("formatted: %s\n", formatted.c_str());
return formatted;
return 1;
}
- llama_sampling_params & sparams = params.sparams;
+ auto & sparams = params.sparams;
#ifndef LOG_DISABLE_LOGS
log_set_target(log_filename_generator("main", "log"));
LOG_TEE("%s: warning: scaling RoPE frequency by %g.\n", __func__, params.rope_freq_scale);
}
- LOG_TEE("%s: build = %d (%s)\n", __func__, LLAMA_BUILD_NUMBER, LLAMA_COMMIT);
- LOG_TEE("%s: built with %s for %s\n", __func__, LLAMA_COMPILER, LLAMA_BUILD_TARGET);
+ print_build_info();
- if (params.seed == LLAMA_DEFAULT_SEED) {
- params.seed = time(NULL);
- }
-
- LOG_TEE("%s: seed = %u\n", __func__, params.seed);
-
- std::mt19937 rng(params.seed);
+ LOG_TEE("%s: seed = %u\n", __func__, params.sparams.seed);
LOG("%s: llama backend init\n", __func__);
llama_backend_init();
llama_numa_init(params.numa);
- llama_model * model;
- llama_context * ctx;
- llama_context * ctx_guidance = NULL;
+ llama_model * model = nullptr;
+ llama_context * ctx = nullptr;
+ gpt_sampler * smpl = nullptr;
+
std::vector<llama_chat_msg> chat_msgs;
+
g_model = &model;
g_ctx = &ctx;
+ g_smpl = &smpl;
// load the model and apply lora adapter, if any
LOG("%s: load the model and apply lora adapter, if any\n", __func__);
model = llama_init.model;
ctx = llama_init.context;
- if (sparams.cfg_scale > 1.f) {
- struct llama_context_params lparams = llama_context_params_from_gpt_params(params);
- ctx_guidance = llama_new_context_with_model(model, lparams);
- }
if (model == NULL) {
LOG_TEE("%s: error: unable to load model\n", __func__);
}
llama_attach_threadpool(ctx, threadpool, threadpool_batch);
- if (ctx_guidance) {
- llama_attach_threadpool(ctx_guidance, threadpool, threadpool_batch);
- }
const int n_ctx_train = llama_n_ctx_train(model);
const int n_ctx = llama_n_ctx(ctx);
}
- // Tokenize negative prompt
- std::vector<llama_token> guidance_inp;
- int guidance_offset = 0;
- int original_prompt_len = 0;
- if (ctx_guidance) {
- LOG("cfg_negative_prompt: \"%s\"\n", log_tostr(sparams.cfg_negative_prompt));
-
- guidance_inp = ::llama_tokenize(ctx_guidance, sparams.cfg_negative_prompt, true, true);
- LOG("guidance_inp tokenized: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_guidance, guidance_inp).c_str());
-
- std::vector<llama_token> original_inp = ::llama_tokenize(ctx, params.prompt, true, true);
- LOG("original_inp tokenized: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, original_inp).c_str());
-
- original_prompt_len = original_inp.size();
- guidance_offset = (int)guidance_inp.size() - original_prompt_len;
- LOG("original_prompt_len: %s", log_tostr(original_prompt_len));
- LOG("guidance_offset: %s", log_tostr(guidance_offset));
- }
-
if ((int) embd_inp.size() > n_ctx - 4) {
LOG_TEE("%s: error: prompt is too long (%d tokens, max %d)\n", __func__, (int) embd_inp.size(), n_ctx - 4);
return 1;
LOG_TEE("%6d -> '%s'\n", embd_inp[i], llama_token_to_piece(ctx, embd_inp[i]).c_str());
}
- if (ctx_guidance) {
- LOG_TEE("\n");
- LOG_TEE("%s: negative prompt: '%s'\n", __func__, sparams.cfg_negative_prompt.c_str());
- LOG_TEE("%s: number of tokens in negative prompt = %zu\n", __func__, guidance_inp.size());
- for (int i = 0; i < (int) guidance_inp.size(); i++) {
- LOG_TEE("%6d -> '%s'\n", guidance_inp[i], llama_token_to_piece(ctx, guidance_inp[i]).c_str());
- }
- }
-
if (params.n_keep > add_bos) {
LOG_TEE("%s: static prompt based on n_keep: '", __func__);
for (int i = 0; i < params.n_keep; i++) {
}
}
}
- LOG_TEE("sampling: \n%s\n", llama_sampling_print(sparams).c_str());
- LOG_TEE("sampling order: \n%s\n", llama_sampling_order_print(sparams).c_str());
+
+ smpl = gpt_sampler_init(model, sparams);
+ if (!smpl) {
+ fprintf(stderr, "%s: failed to initialize sampling subsystem\n", __func__);
+ exit(1);
+ }
+
+ LOG_TEE("sampling params: \n%s\n", sparams.print().c_str());
+ LOG_TEE(" sampler constr: \n%s\n", gpt_sampler_print(smpl).c_str());
LOG_TEE("generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep);
// group-attention state
int n_remain = params.n_predict;
int n_consumed = 0;
int n_session_consumed = 0;
- int n_past_guidance = 0;
std::vector<int> input_tokens; g_input_tokens = &input_tokens;
std::vector<int> output_tokens; g_output_tokens = &output_tokens;
display = params.display_prompt;
std::vector<llama_token> embd;
- std::vector<llama_token> embd_guidance;
// tokenized antiprompts
std::vector<std::vector<llama_token>> antiprompt_ids;
antiprompt_ids.emplace_back(::llama_tokenize(ctx, antiprompt, false, true));
}
- struct llama_sampling_context * ctx_sampling = llama_sampling_init(sparams);
- if (!ctx_sampling) {
- fprintf(stderr, "%s: failed to initialize sampling subsystem\n", __func__);
- exit(1);
- }
-
if (llama_model_has_encoder(model)) {
int enc_input_size = embd_inp.size();
llama_token * enc_input_buf = embd_inp.data();
// if we run out of context:
// - take the n_keep first tokens from the original prompt (via n_past)
// - take half of the last (n_ctx - n_keep) tokens and recompute the logits in batches
- if (n_past + (int) embd.size() + std::max<int>(0, guidance_offset) >= n_ctx) {
+ if (n_past + (int) embd.size() >= n_ctx) {
if (params.n_predict == -2) {
LOG_TEE("\n\n%s: context full and n_predict == -%d => stopping\n", __func__, params.n_predict);
break;
n_past -= n_discard;
- if (ctx_guidance) {
- n_past_guidance -= n_discard;
- }
-
- LOG("after swap: n_past = %d, n_past_guidance = %d\n", n_past, n_past_guidance);
+ LOG("after swap: n_past = %d\n", n_past);
LOG("embd: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd).c_str());
}
}
- // evaluate tokens in batches
- // embd is typically prepared beforehand to fit within a batch, but not always
- if (ctx_guidance) {
- int input_size = 0;
- llama_token * input_buf = NULL;
-
- if (n_past_guidance < (int) guidance_inp.size()) {
- // Guidance context should have the same data with these modifications:
- //
- // * Replace the initial prompt
- // * Shift everything by guidance_offset
- embd_guidance = guidance_inp;
- if (embd.begin() + original_prompt_len < embd.end()) {
- embd_guidance.insert(
- embd_guidance.end(),
- embd.begin() + original_prompt_len,
- embd.end()
- );
- }
-
- input_buf = embd_guidance.data();
- input_size = embd_guidance.size();
-
- LOG("guidance context: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_guidance).c_str());
- } else {
- input_buf = embd.data();
- input_size = embd.size();
- }
-
- for (int i = 0; i < input_size; i += params.n_batch) {
- int n_eval = std::min(input_size - i, params.n_batch);
- if (llama_decode(ctx_guidance, llama_batch_get_one(input_buf + i, n_eval, n_past_guidance, 0))) {
- LOG_TEE("%s : failed to eval\n", __func__);
- return 1;
- }
-
- n_past_guidance += n_eval;
- }
- }
-
for (int i = 0; i < (int) embd.size(); i += params.n_batch) {
int n_eval = (int) embd.size() - i;
if (n_eval > params.n_batch) {
}
embd.clear();
- embd_guidance.clear();
if ((int) embd_inp.size() <= n_consumed && !is_interacting) {
// optionally save the session on first sample (for faster prompt loading next time)
LOG("saved session to %s\n", path_session.c_str());
}
- const llama_token id = llama_sampling_sample(ctx_sampling, ctx, ctx_guidance);
+ const llama_token id = gpt_sampler_sample(smpl, ctx, -1);
- llama_sampling_accept(ctx_sampling, ctx, id, /* apply_grammar= */ true);
+ gpt_sampler_accept(smpl, id, /* apply_grammar= */ true);
- LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, ctx_sampling->prev).c_str());
+ // LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, smpl->prev.to_vector()).c_str());
embd.push_back(id);
// push the prompt in the sampling context in order to apply repetition penalties later
// for the prompt, we don't apply grammar rules
- llama_sampling_accept(ctx_sampling, ctx, embd_inp[n_consumed], /* apply_grammar= */ false);
+ gpt_sampler_accept(smpl, embd_inp[n_consumed], /* apply_grammar= */ false);
++n_consumed;
if ((int) embd.size() >= params.n_batch) {
// check for reverse prompt in the last n_prev tokens
if (!params.antiprompt.empty()) {
const int n_prev = 32;
- const std::string last_output = llama_sampling_prev_str(ctx_sampling, ctx, n_prev);
+ const std::string last_output = gpt_sampler_prev_str(smpl, ctx, n_prev);
is_antiprompt = false;
// Check if each of the reverse prompts appears at the end of the output.
}
// check for reverse prompt using special tokens
- llama_token last_token = llama_sampling_last(ctx_sampling);
+ llama_token last_token = gpt_sampler_last(smpl);
for (std::vector<llama_token> ids : antiprompt_ids) {
if (ids.size() == 1 && last_token == ids[0]) {
if (params.interactive) {
}
// deal with end of generation tokens in interactive mode
- if (llama_token_is_eog(model, llama_sampling_last(ctx_sampling))) {
+ if (llama_token_is_eog(model, gpt_sampler_last(smpl))) {
LOG("found an EOG token\n");
if (params.interactive) {
// if current token is not EOG, we add it to current assistant message
if (params.conversation) {
- auto id = llama_sampling_last(ctx_sampling);
+ const auto id = gpt_sampler_last(smpl);
assistant_ss << llama_token_to_piece(ctx, id, false);
}
if (n_past > 0) {
if (is_interacting) {
- llama_sampling_reset(ctx_sampling);
+ gpt_sampler_reset(smpl);
}
is_interacting = false;
}
llama_state_save_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.size());
}
- llama_print_timings(ctx);
+ LOG_TEE("\n");
+ gpt_perf_print(ctx, smpl);
write_logfile(ctx, params, model, input_tokens, output_ss.str(), output_tokens);
- if (ctx_guidance) { llama_free(ctx_guidance); }
+ gpt_sampler_free(smpl);
+
llama_free(ctx);
llama_free_model(model);
- llama_sampling_free(ctx_sampling);
llama_backend_free();
ggml_threadpool_free(threadpool);
struct client {
~client() {
- if (ctx_sampling) {
- llama_sampling_free(ctx_sampling);
+ if (smpl) {
+ gpt_sampler_free(smpl);
}
}
std::string prompt;
std::string response;
- struct llama_sampling_context * ctx_sampling = nullptr;
+ struct gpt_sampler * smpl = nullptr;
};
static void print_date_time() {
for (size_t i = 0; i < clients.size(); ++i) {
auto & client = clients[i];
client.id = i;
- client.ctx_sampling = llama_sampling_init(params.sparams);
+ client.smpl = gpt_sampler_init(model, params.sparams);
}
std::vector<llama_token> tokens_system;
client.prompt = client.input + "\nAssistant:";
client.response = "";
- llama_sampling_reset(client.ctx_sampling);
+ gpt_sampler_reset(client.smpl);
// do not prepend BOS because we have a system prompt!
std::vector<llama_token> tokens_prompt;
//printf("client %d, seq %d, token %d, pos %d, batch %d\n",
// client.id, client.seq_id, client.sampled, client.n_decoded, client.i_batch);
- const llama_token id = llama_sampling_sample(client.ctx_sampling, ctx, NULL, client.i_batch - i);
+ const llama_token id = gpt_sampler_sample(client.smpl, ctx, client.i_batch - i);
- llama_sampling_accept(client.ctx_sampling, ctx, id, true);
+ gpt_sampler_accept(client.smpl, id, true);
if (client.n_decoded == 1) {
// start measuring generation time after the first token to make sure all concurrent clients
}
// delete only the generated part of the sequence, i.e. keep the system prompt in the cache
- llama_kv_cache_seq_rm(ctx, client.id + 1, -1, -1);
+ llama_kv_cache_seq_rm(ctx, client.id + 1, -1, -1);
llama_kv_cache_seq_cp(ctx, 0, client.id + 1, -1, -1);
const auto t_main_end = ggml_time_us();
LOG_TEE("\n");
- llama_print_timings(ctx);
+ // TODO: print sampling/grammar timings for all clients
+ llama_perf_print(ctx, LLAMA_PERF_TYPE_CONTEXT);
llama_batch_free(batch);
return 1;
}
- srand(params.seed == LLAMA_DEFAULT_SEED ? time(NULL) : params.seed);
-
int n_junk = params.n_junk;
int n_keep = params.n_keep;
int n_grp = params.grp_attn_n;
GGML_ASSERT(ctx_params.n_batch % n_grp == 0 && "n_batch must be divisible by n_grp");
llama_context * ctx = llama_new_context_with_model(model, ctx_params);
-
if (ctx == NULL) {
fprintf(stderr , "%s: error: failed to create the llama_context\n" , __func__);
return 1;
}
+ auto sparams = llama_sampler_chain_default_params();
+
+ llama_sampler * smpl = llama_sampler_chain_init(sparams);
+
+ llama_sampler_chain_add(smpl, llama_sampler_init_greedy());
+
// tokenize the prompt
std::vector<llama_token> tokens_list;
tokens_list = ::llama_tokenize(ctx, params.prompt, true);
while (n_cur <= n_len) {
// sample the next token
{
- auto n_vocab = llama_n_vocab(model);
- auto * logits = llama_get_logits_ith(ctx, batch.n_tokens - 1);
+ const llama_token new_token_id = llama_sampler_sample(smpl, ctx, batch.n_tokens - 1);
- std::vector<llama_token_data> candidates;
- candidates.reserve(n_vocab);
-
- for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
- candidates.emplace_back(llama_token_data{ token_id, logits[token_id], 0.0f });
- }
-
- llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
-
- // sample the most likely token
- const llama_token new_token_id = llama_sample_token_greedy(ctx, &candidates_p);
+ llama_sampler_accept(smpl, new_token_id);
// is it an end of generation?
if (llama_token_is_eog(model, new_token_id) || n_cur == n_len) {
LOG_TEE("%s: decoded %d tokens in %.2f s, speed: %.2f t/s\n",
__func__, n_decode, (t_main_end - t_main_start) / 1000000.0f, n_decode / ((t_main_end - t_main_start) / 1000000.0f));
- llama_print_timings(ctx);
+ LOG_TEE("\n");
+ llama_perf_print(ctx, LLAMA_PERF_TYPE_CONTEXT);
fprintf(stderr, "\n");
+ llama_sampler_free(smpl);
+
llama_batch_free(batch);
llama_free(ctx);
fprintf(logfile, "ppl_value: %f\n", results.ppl_value);
yaml_dump_vector_float(logfile, "probs", results.probs);
- llama_dump_timing_info_yaml(logfile, ctx);
+ llama_perf_dump_yaml(logfile, ctx);
fclose(logfile);
}
print_build_info();
- if (params.seed == LLAMA_DEFAULT_SEED) {
- params.seed = time(NULL);
- }
-
- fprintf(stderr, "%s: seed = %u\n", __func__, params.seed);
-
- std::mt19937 rng(params.seed);
+ LOG_TEE("%s: seed = %u\n", __func__, params.sparams.seed);
llama_backend_init();
llama_numa_init(params.numa);
results = perplexity(ctx, params, n_ctx);
}
- llama_print_timings(ctx);
+ LOG_TEE("\n");
+ llama_perf_print(ctx, LLAMA_PERF_TYPE_CONTEXT);
write_logfile(ctx, params, model, results);
llama_free(ctx);
-#define LLAMA_API_INTERNAL
#include "common.h"
#include "ggml.h"
#include "llama.h"
+#include "llama-impl.h"
#include <algorithm>
#include <cassert>
}
auto cparams = llama_context_default_params();
- cparams.n_ctx = 256;
- cparams.seed = 1;
+ cparams.n_ctx = 256;
ctx = llama_new_context_with_model(model, cparams);
}
}
+ LOG_TEE("\n");
+ llama_perf_print(ctx, LLAMA_PERF_TYPE_CONTEXT);
+
// clean up
llama_batch_free(query_batch);
- llama_print_timings(ctx);
llama_free(ctx);
llama_free_model(model);
llama_backend_free();
#include <vector>
#include <cstdio>
-#include <chrono>
int main(int argc, char ** argv) {
gpt_params params;
params.prompt = "The quick brown fox";
+ params.sparams.seed = 1234;
if (!gpt_params_parse(argc, argv, params)) {
gpt_params_print_usage(argc, argv, params);
return 1;
}
+ auto sparams = llama_sampler_chain_default_params();
+
+ llama_sampler * smpl = llama_sampler_chain_init(sparams);
+
+ llama_sampler_chain_add(smpl, llama_sampler_init_softmax());
+ llama_sampler_chain_add(smpl, llama_sampler_init_dist(params.sparams.seed));
+
// tokenize prompt
auto tokens = llama_tokenize(ctx, params.prompt, true);
printf("\nfirst run: %s", params.prompt.c_str());
for (auto i = 0; i < params.n_predict; i++) {
- auto * logits = llama_get_logits(ctx);
- auto n_vocab = llama_n_vocab(model);
-
- std::vector<llama_token_data> candidates;
- candidates.reserve(n_vocab);
- for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
- candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
- }
- llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
- auto next_token = llama_sample_token(ctx, &candidates_p);
+ auto next_token = llama_sampler_sample(smpl, ctx, -1);
auto next_token_str = llama_token_to_piece(ctx, next_token);
+ llama_sampler_accept(smpl, next_token);
+
printf("%s", next_token_str.c_str());
result0 += next_token_str;
// make new context
auto * ctx2 = llama_new_context_with_model(model, llama_context_params_from_gpt_params(params));
+ llama_sampler * smpl2 = llama_sampler_chain_init(sparams);
+
+ llama_sampler_chain_add(smpl2, llama_sampler_init_softmax());
+ llama_sampler_chain_add(smpl2, llama_sampler_init_dist(params.sparams.seed));
+
printf("\nsecond run: %s", params.prompt.c_str());
// load state (rng, logits, embedding and kv_cache) from file
// second run
for (auto i = 0; i < params.n_predict; i++) {
- auto * logits = llama_get_logits(ctx2);
- auto n_vocab = llama_n_vocab(model);
- std::vector<llama_token_data> candidates;
- candidates.reserve(n_vocab);
- for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
- candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
- }
- llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
- auto next_token = llama_sample_token(ctx2, &candidates_p);
+ auto next_token = llama_sampler_sample(smpl2, ctx2, -1);
auto next_token_str = llama_token_to_piece(ctx2, next_token);
+ llama_sampler_accept(smpl2, next_token);
+
printf("%s", next_token_str.c_str());
result1 += next_token_str;
}
// make new context
- auto* ctx3 = llama_new_context_with_model(model, llama_context_params_from_gpt_params(params));
+ auto * ctx3 = llama_new_context_with_model(model, llama_context_params_from_gpt_params(params));
+
+ llama_sampler * smpl3 = llama_sampler_chain_init(sparams);
+
+ llama_sampler_chain_add(smpl3, llama_sampler_init_softmax());
+ llama_sampler_chain_add(smpl3, llama_sampler_init_dist(params.sparams.seed));
printf("\nsingle seq run: %s", params.prompt.c_str());
// third run with seq 1 instead of 0
for (auto i = 0; i < params.n_predict; i++) {
- auto * logits = llama_get_logits(ctx3);
- auto n_vocab = llama_n_vocab(model);
- std::vector<llama_token_data> candidates;
- candidates.reserve(n_vocab);
- for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
- candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
- }
- llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
- auto next_token = llama_sample_token(ctx3, &candidates_p);
+ auto next_token = llama_sampler_sample(smpl3, ctx3, -1);
auto next_token_str = llama_token_to_piece(ctx3, next_token);
+ llama_sampler_accept(smpl3, next_token);
+
printf("%s", next_token_str.c_str());
result2 += next_token_str;
printf("\n");
+ llama_sampler_free(smpl);
+ llama_sampler_free(smpl2);
+ llama_sampler_free(smpl3);
+
llama_free(ctx3);
llama_free_model(model);
`frequency_penalty`: Repeat alpha frequency penalty. Default: `0.0`, which is disabled.
- `penalty_prompt`: This will replace the `prompt` for the purpose of the penalty evaluation. Can be either `null`, a string or an array of numbers representing tokens. Default: `null`, which is to use the original `prompt`.
-
`mirostat`: Enable Mirostat sampling, controlling perplexity during text generation. Default: `0`, where `0` is disabled, `1` is Mirostat, and `2` is Mirostat 2.0.
`mirostat_tau`: Set the Mirostat target entropy, parameter tau. Default: `5.0`
"stopping_word": ""
},
"penalize_nl": true,
- "penalty_prompt_tokens": [],
"presence_penalty": 0.0,
"prompt": "Say hello to llama.cpp",
"repeat_last_n": 64,
"tfs_z": 1.0,
"top_k": 40,
"top_p": 0.949999988079071,
- "typical_p": 1.0,
- "use_penalty_prompt_tokens": false
+ "typical_p": 1.0
}
]
```
#include "common.h"
#include "json-schema-to-grammar.h"
#include "llama.h"
-#include "grammar-parser.h"
// Change JSON_ASSERT from assert() to GGML_ASSERT:
#define JSON_ASSERT GGML_ASSERT
std::string stopping_word;
// sampling
- llama_token sampled;
- struct llama_sampling_params sparams;
- llama_sampling_context * ctx_sampling = nullptr;
json json_schema;
+ struct gpt_sampler_params sparams;
+ struct gpt_sampler * smpl = nullptr;
+
+ llama_token sampled;
+
int32_t ga_i = 0; // group-attention state
int32_t ga_n = 1; // group-attention factor
int32_t ga_w = 512; // group-attention width
// Clear any sampling context
for (server_slot & slot : slots) {
- if (slot.ctx_sampling != nullptr) {
- llama_sampling_free(slot.ctx_sampling);
+ if (slot.smpl != nullptr) {
+ gpt_sampler_free(slot.smpl);
}
}
bool launch_slot_with_task(server_slot & slot, const server_task & task) {
slot_params default_params;
// Sampling parameter defaults are loaded from the global server context (but individual requests can still override them)
- llama_sampling_params default_sparams = params.sparams;
- auto & data = task.data;
+ auto default_sparams = params.sparams;
+ const auto & data = task.data;
if (data.count("__oaicompat") != 0) {
slot.oaicompat = true;
slot.sparams.top_p = json_value(data, "top_p", default_sparams.top_p);
slot.sparams.min_p = json_value(data, "min_p", default_sparams.min_p);
slot.sparams.tfs_z = json_value(data, "tfs_z", default_sparams.tfs_z);
- slot.sparams.typical_p = json_value(data, "typical_p", default_sparams.typical_p);
+ slot.sparams.typ_p = json_value(data, "typical_p", default_sparams.typ_p);
slot.sparams.temp = json_value(data, "temperature", default_sparams.temp);
slot.sparams.dynatemp_range = json_value(data, "dynatemp_range", default_sparams.dynatemp_range);
slot.sparams.dynatemp_exponent = json_value(data, "dynatemp_exponent", default_sparams.dynatemp_exponent);
if (data.contains("json_schema") && !data.at("json_schema").is_null() && data.contains("grammar") && !data.at("grammar").is_null()) {
send_error(task, "Either \"json_schema\" or \"grammar\" can be specified, but not both", ERROR_TYPE_INVALID_REQUEST);
return false;
- } else if (data.contains("json_schema") && !data.contains("grammar")) {
+ }
+ if (data.contains("json_schema") && !data.contains("grammar")) {
try {
auto schema = json_value(data, "json_schema", json::object());
slot.sparams.grammar = json_schema_to_grammar(schema);
}
}
- // penalize user-provided tokens
- {
- slot.sparams.penalty_prompt_tokens.clear();
- slot.sparams.use_penalty_prompt_tokens = false;
-
- const auto & penalty_prompt = data.find("penalty_prompt");
-
- if (penalty_prompt != data.end()) {
- if (penalty_prompt->is_string()) {
- const auto penalty_prompt_string = penalty_prompt->get<std::string>();
- slot.sparams.penalty_prompt_tokens = llama_tokenize(model, penalty_prompt_string, false);
-
- if (slot.params.n_predict > 0) {
- slot.sparams.penalty_prompt_tokens.reserve(slot.sparams.penalty_prompt_tokens.size() + slot.params.n_predict);
- }
- slot.sparams.use_penalty_prompt_tokens = true;
-
- LOG_VERBOSE("penalty_prompt_tokens", {
- {"id_slot", slot.id},
- {"tokens", slot.sparams.penalty_prompt_tokens},
- });
- }
- else if (penalty_prompt->is_array()) {
- const auto n_tokens = penalty_prompt->size();
- slot.sparams.penalty_prompt_tokens.reserve(n_tokens + std::max(0, slot.params.n_predict));
-
- const int n_vocab = llama_n_vocab(model);
- for (const auto & penalty_token : *penalty_prompt) {
- if (penalty_token.is_number_integer()) {
- const auto tok = penalty_token.get<llama_token>();
- if (tok >= 0 && tok < n_vocab) {
- slot.sparams.penalty_prompt_tokens.push_back(tok);
- }
- }
- }
- slot.sparams.use_penalty_prompt_tokens = true;
-
- LOG_VERBOSE("penalty_prompt_tokens", {
- {"id_slot", slot.id},
- {"tokens", slot.sparams.penalty_prompt_tokens},
- });
- }
- }
- }
-
{
slot.sparams.logit_bias.clear();
if (json_value(data, "ignore_eos", false) && has_eos_token) {
- slot.sparams.logit_bias[llama_token_eos(model)] = -INFINITY;
+ slot.sparams.logit_bias.push_back({llama_token_eos(model), -INFINITY});
}
const auto & logit_bias = data.find("logit_bias");
if (el[0].is_number_integer()) {
llama_token tok = el[0].get<llama_token>();
if (tok >= 0 && tok < n_vocab) {
- slot.sparams.logit_bias[tok] = bias;
+ slot.sparams.logit_bias.push_back({tok, bias});
}
} else if (el[0].is_string()) {
auto toks = llama_tokenize(model, el[0].get<std::string>(), false);
for (auto tok : toks) {
- slot.sparams.logit_bias[tok] = bias;
+ slot.sparams.logit_bias.push_back({tok, bias});
}
}
}
}
{
- const auto & samplers_sequence = data.find("samplers");
- if (samplers_sequence != data.end() && samplers_sequence->is_array()) {
+ const auto & samplers = data.find("samplers");
+ if (samplers != data.end() && samplers->is_array()) {
std::vector<std::string> sampler_names;
- for (const auto & sampler_name : *samplers_sequence) {
- if (sampler_name.is_string()) {
- sampler_names.emplace_back(sampler_name);
+ for (const auto & name : *samplers) {
+ if (name.is_string()) {
+ sampler_names.emplace_back(name);
}
}
- slot.sparams.samplers_sequence = llama_sampling_types_from_names(sampler_names, false);
+ slot.sparams.samplers = gpt_sampler_types_from_names(sampler_names, false);
} else {
- slot.sparams.samplers_sequence = default_sparams.samplers_sequence;
+ slot.sparams.samplers = default_sparams.samplers;
}
}
{
- if (slot.ctx_sampling != nullptr) {
- llama_sampling_free(slot.ctx_sampling);
+ if (slot.smpl != nullptr) {
+ gpt_sampler_free(slot.smpl);
}
- slot.ctx_sampling = llama_sampling_init(slot.sparams);
- if (slot.ctx_sampling == nullptr) {
+
+ slot.smpl = gpt_sampler_init(model, slot.sparams);
+ if (slot.smpl == nullptr) {
// for now, the only error that may happen here is invalid grammar
send_error(task, "Failed to parse grammar", ERROR_TYPE_INVALID_REQUEST);
return false;
slot.generated_text += token_str;
slot.has_next_token = true;
- if (slot.ctx_sampling->params.use_penalty_prompt_tokens && result.tok != -1) {
- // we can change penalty_prompt_tokens because it is always created from scratch each request
- slot.ctx_sampling->params.penalty_prompt_tokens.push_back(result.tok);
- }
-
// check if there is incomplete UTF-8 character at the end
bool incomplete = false;
for (unsigned i = 1; i < 5 && i <= slot.generated_text.size(); ++i) {
}
json get_formated_generation(const server_slot & slot) const {
- const auto eos_bias = slot.sparams.logit_bias.find(llama_token_eos(model));
- const bool ignore_eos = eos_bias != slot.sparams.logit_bias.end() && eos_bias->second < 0.0f && std::isinf(eos_bias->second);
-
- std::vector<std::string> samplers_sequence;
- samplers_sequence.reserve(slot.sparams.samplers_sequence.size());
- for (const auto & sampler_type : slot.sparams.samplers_sequence) {
- samplers_sequence.emplace_back(llama_sampling_type_to_str(sampler_type));
+ std::vector<std::string> samplers;
+ samplers.reserve(slot.sparams.samplers.size());
+ for (const auto & sampler : slot.sparams.samplers) {
+ samplers.emplace_back(gpt_sampler_type_to_str(sampler));
}
return json {
{"top_p", slot.sparams.top_p},
{"min_p", slot.sparams.min_p},
{"tfs_z", slot.sparams.tfs_z},
- {"typical_p", slot.sparams.typical_p},
+ {"typical_p", slot.sparams.typ_p},
{"repeat_last_n", slot.sparams.penalty_last_n},
{"repeat_penalty", slot.sparams.penalty_repeat},
{"presence_penalty", slot.sparams.penalty_present},
{"frequency_penalty", slot.sparams.penalty_freq},
- {"penalty_prompt_tokens", slot.sparams.penalty_prompt_tokens},
- {"use_penalty_prompt_tokens", slot.sparams.use_penalty_prompt_tokens},
{"mirostat", slot.sparams.mirostat},
{"mirostat_tau", slot.sparams.mirostat_tau},
{"mirostat_eta", slot.sparams.mirostat_eta},
{"max_tokens", slot.params.n_predict}, // User configured n_predict
{"n_keep", slot.params.n_keep},
{"n_discard", slot.params.n_discard},
- {"ignore_eos", ignore_eos},
+ {"ignore_eos", slot.sparams.ignore_eos},
{"stream", slot.params.stream},
- {"logit_bias", slot.sparams.logit_bias},
+ //{"logit_bias", slot.sparams.logit_bias},
{"n_probs", slot.sparams.n_probs},
{"min_keep", slot.sparams.min_keep},
{"grammar", slot.sparams.grammar},
- {"samplers", samplers_sequence}
+ {"samplers", samplers},
};
}
GGML_ASSERT(slot.n_prompt_tokens < slot.n_ctx);
}
- llama_sampling_reset(slot.ctx_sampling);
+ gpt_sampler_reset(slot.smpl);
if (!slot.params.cache_prompt) {
slot.n_past_se = 0;
// push the prompt into the sampling context (do not apply grammar)
for (int i = 0; i < slot.n_past; ++i) {
- llama_sampling_accept(slot.ctx_sampling, ctx, slot.cache_tokens[i], false);
+ gpt_sampler_accept(slot.smpl, slot.cache_tokens[i], false);
}
}
}
slot.n_past_se = 0;
slot.ga_i = 0;
// TODO: is the system prompt ever in the sampling context?
- llama_sampling_reset(slot.ctx_sampling);
+ gpt_sampler_reset(slot.smpl);
}
// remove the non-common part from the cache
slot.release();
slot.i_batch = -1;
continue; // continue loop of slots
- } else {
- // prompt evaluated for next-token prediction
- slot.state = SLOT_STATE_GENERATING;
}
+
+ // prompt evaluated for next-token prediction
+ slot.state = SLOT_STATE_GENERATING;
} else if (slot.state != SLOT_STATE_GENERATING) {
continue; // continue loop of slots
}
completion_token_output result;
- const llama_token id = llama_sampling_sample(slot.ctx_sampling, ctx, NULL, slot.i_batch - i);
+ const llama_token id = gpt_sampler_sample(slot.smpl, ctx, slot.i_batch - i);
- llama_sampling_accept(slot.ctx_sampling, ctx, id, true);
+ gpt_sampler_accept(slot.smpl, id, true);
slot.n_decoded += 1;
if (slot.n_decoded == 1) {
metrics.on_prompt_eval(slot);
}
- llama_token_data_array cur_p = { slot.ctx_sampling->cur.data(), slot.ctx_sampling->cur.size(), false };
result.tok = id;
- const size_t n_probs = std::min(cur_p.size, (size_t) slot.sparams.n_probs);
- if (n_probs > 0) {
- const size_t n_valid = slot.ctx_sampling->n_valid;
+ const auto * cur_p = gpt_sampler_get_candidates(slot.smpl);
- // Make sure at least n_probs top tokens are at the front of the vector:
- if (slot.sparams.temp == 0.0f && n_probs > n_valid) {
- llama_sample_top_k(ctx, &cur_p, n_probs, 0);
- }
-
- if (slot.sparams.temp == 0.0f) {
- // With greedy sampling the probabilities have possibly not been calculated.
- for (size_t i = 0; i < n_probs; ++i) {
- result.probs.push_back({
- cur_p.data[i].id,
- i == 0 ? 1.0f : 0.0f
- });
- }
- } else {
- for (size_t i = 0; i < n_probs; ++i) {
- result.probs.push_back({
- cur_p.data[i].id,
- i >= n_valid ? 0.0f : cur_p.data[i].p // Tokens filtered out due to e.g. top_k have 0 probability.
- });
- }
- }
+ // push at most cur_p->size entries to avoid reading past the end of the candidates array
+ for (size_t i = 0; i < std::min(cur_p->size, (size_t) slot.sparams.n_probs); ++i) {
+ result.probs.push_back({
+ cur_p->data[i].id,
+ cur_p->data[i].p,
+ });
}
if (!process_token(result, slot)) {
return 1;
}
+ auto sparams = llama_sampler_chain_default_params();
+
+ sparams.no_perf = false;
+
+ llama_sampler * smpl = llama_sampler_chain_init(sparams);
+
+ llama_sampler_chain_add(smpl, llama_sampler_init_greedy());
+
// tokenize the prompt
std::vector<llama_token> tokens_list;
while (n_cur <= n_predict) {
// sample the next token
{
- auto n_vocab = llama_n_vocab(model);
- auto * logits = llama_get_logits_ith(ctx, batch.n_tokens - 1);
-
- std::vector<llama_token_data> candidates;
- candidates.reserve(n_vocab);
+ const llama_token new_token_id = llama_sampler_sample(smpl, ctx, batch.n_tokens - 1);
- for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
- candidates.emplace_back(llama_token_data{ token_id, logits[token_id], 0.0f });
- }
-
- llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
-
- // sample the most likely token
- const llama_token new_token_id = llama_sample_token_greedy(ctx, &candidates_p);
+ llama_sampler_accept(smpl, new_token_id);
// is it an end of generation?
if (llama_token_is_eog(model, new_token_id) || n_cur == n_predict) {
LOG_TEE("%s: decoded %d tokens in %.2f s, speed: %.2f t/s\n",
__func__, n_decode, (t_main_end - t_main_start) / 1000000.0f, n_decode / ((t_main_end - t_main_start) / 1000000.0f));
- llama_print_timings(ctx);
+ LOG_TEE("\n");
+ llama_perf_print(smpl, LLAMA_PERF_TYPE_SAMPLER_CHAIN);
+ llama_perf_print(ctx, LLAMA_PERF_TYPE_CONTEXT);
fprintf(stderr, "\n");
llama_batch_free(batch);
-
+ llama_sampler_free(smpl);
llama_free(ctx);
llama_free_model(model);
std::vector<llama_token> tokens;
std::vector<std::vector<llama_token_data>> dists;
- struct llama_sampling_context * ctx_sampling;
+ struct gpt_sampler * smpl = nullptr;
};
int main(int argc, char ** argv) {
// probability threshold for splitting a draft branch (only for n_seq_dft > 1)
const float p_split = params.p_split;
- if (params.seed == LLAMA_DEFAULT_SEED) {
- params.seed = time(NULL);
- }
- std::default_random_engine rng(params.seed);
+ std::default_random_engine rng(params.sparams.seed);
std::uniform_real_distribution<> u_dist;
#ifndef LOG_DISABLE_LOGS
// used to determine end of generation
bool has_eos = false;
- // target model sampling context
- struct llama_sampling_context * ctx_sampling = llama_sampling_init(params.sparams);
+ // target model sampling context (reuse the llama_context's sampling instance)
+ struct gpt_sampler * smpl = gpt_sampler_init(model_tgt, params.sparams);
+
+ struct llama_sampler * softmax = llama_sampler_init_softmax();
// draft sequence data
std::vector<seq_draft> drafts(n_seq_dft);
- params.sparams.grammar.clear(); // the draft samplers will copy the target sampler's grammar
- if (params.sparams.temp == 0) {
- params.sparams.temp = -1.0f; // force greedy sampling with probs for the draft model
- }
-
for (int s = 0; s < n_seq_dft; ++s) {
- drafts[s].ctx_sampling = llama_sampling_init(params.sparams);
+ // allocate gpt_sampler for each draft sequence
+ drafts[s].smpl = gpt_sampler_init(model_dft, params.sparams);
}
llama_batch batch_dft = llama_batch_init(params.n_ctx, 0, 1);
bool accept = false;
if (params.sparams.temp > 0) {
// stochastic verification
+ gpt_sampler_sample(smpl, ctx_tgt, drafts[s_keep].i_batch_tgt[i_dft], true);
- llama_token_data_array dist_tgt = llama_sampling_prepare(ctx_sampling, ctx_tgt, NULL, drafts[s_keep].i_batch_tgt[i_dft], true, NULL);
- llama_sample_softmax(ctx_tgt, &dist_tgt);
- float p_tgt = 0, p_dft = 0;
+ auto & dist_tgt = *gpt_sampler_get_candidates(smpl);
- // GGML_ASSERT(dist_tgt.size() == dist_dft.size());
+ float p_tgt = 0.0f;
+ float p_dft = 0.0f;
while (active_seqs.size() > 0) {
// randomly select a sequence to verify from active sequences
}
continue;
}
+
LOG("verifying sequence #%d at pos #%d from %d active sequence(s)\n", s, i_dft, (int) active_seqs.size());
float r = u_dist(rng);
- llama_token_data_array dist_dft = { drafts[s].dists[i_dft].data() , drafts[s].dists[i_dft].size(), true };
+ llama_token_data_array dist_dft = { drafts[s].dists[i_dft].data() , drafts[s].dists[i_dft].size(), LLAMA_TOKEN_NULL, true };
+
+ //GGML_ASSERT(dist_tgt.size <= dist_dft.size);
+
// acquire the token probabilities assigned by the draft and target models
for (size_t i = 0; i < dist_tgt.size; i++) {
if (dist_tgt.data[i].id == drafts[s].tokens[i_dft]) {
accept = true;
token_id = drafts[s].tokens[i_dft];
token_str = llama_token_to_piece(ctx_tgt, token_id);
- llama_sampling_accept(ctx_sampling, ctx_tgt, token_id, true);
+ gpt_sampler_accept(smpl, token_id, true);
LOG("draft token %d of sequence %d (%d, '%s') accepted\n", i_dft, s, token_id, token_str.c_str());
break;
// calculate residual probability
GGML_ASSERT(dist_tgt.sorted);
GGML_ASSERT(dist_dft.sorted);
- float sum_probs = 0.0f;
// sort dist by id
std::sort(dist_tgt.data, dist_tgt.data + dist_tgt.size, [](const llama_token_data &a, const llama_token_data &b) {
return a.id < b.id;
});
+ float sum_probs = 0.0f;
+
for (size_t i = 0; i < dist_tgt.size; i++) {
- dist_tgt.data[i].p = std::max(0.0f, dist_tgt.data[i].p - dist_dft.data[i].p);
+ if (i < dist_dft.size) {
+ dist_tgt.data[i].p = std::max(0.0f, dist_tgt.data[i].p - dist_dft.data[i].p);
+ } else {
+ dist_tgt.data[i].p = std::max(0.0f, dist_tgt.data[i].p);
+ }
+
sum_probs += dist_tgt.data[i].p;
}
+
for (size_t i = 0; i < dist_tgt.size; i++) {
dist_tgt.data[i].p /= sum_probs;
}
// all drafted tokens were rejected
// sample from the target model
LOG("all drafted tokens were rejected, sampling from residual distribution\n");
- token_id = llama_sample_token(ctx_tgt, &dist_tgt);
- llama_sampling_accept(ctx_sampling, ctx_tgt, token_id, true);
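+ // draw the token from the residual (target minus draft) distribution using the local rng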
+ std::vector<float> probs(dist_tgt.size);
+ for (size_t i = 0; i < dist_tgt.size; ++i) {
+ probs[i] = dist_tgt.data[i].p;
+ }
+
+ std::discrete_distribution<> dist(probs.begin(), probs.end());
+
+ const int idx = dist(rng);
+
+ token_id = dist_tgt.data[idx].id;
+ gpt_sampler_accept(smpl, token_id, true);
token_str = llama_token_to_piece(ctx_tgt, token_id);
}
-
} else {
// greedy verification
// sample from the target model
LOG("sampling target: s_keep = %3d, i_dft = %3d, i_batch_tgt = %3d\n", s_keep, i_dft, drafts[s_keep].i_batch_tgt[i_dft]);
- token_id = llama_sampling_sample(ctx_sampling, ctx_tgt, NULL, drafts[s_keep].i_batch_tgt[i_dft]);
+ token_id = gpt_sampler_sample(smpl, ctx_tgt, drafts[s_keep].i_batch_tgt[i_dft]);
- llama_sampling_accept(ctx_sampling, ctx_tgt, token_id, true);
+ gpt_sampler_accept(smpl, token_id, true);
- //LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_tgt, ctx_sampling->prev).c_str());
+ //LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_tgt, smpl->prev).c_str());
token_str = llama_token_to_piece(ctx_tgt, token_id);
break;
}
- llama_sampling_cp(ctx_sampling, drafts[0].ctx_sampling);
+ if (drafts[0].smpl) {
+ gpt_sampler_free(drafts[0].smpl);
+ }
+ drafts[0].smpl = gpt_sampler_clone(smpl);
int n_seq_cur = 1;
int n_past_cur = n_past_dft;
continue;
}
- llama_sampling_sample(drafts[s].ctx_sampling, ctx_dft, NULL, drafts[s].i_batch_dft);
+ gpt_sampler_sample(drafts[s].smpl, ctx_dft, drafts[s].i_batch_dft, true);
- const auto & cur_p = drafts[s].ctx_sampling->cur;
+ const auto * cur_p = gpt_sampler_get_candidates(drafts[s].smpl);
- for (int k = 0; k < std::min(n_seq_dft + 3, (int) cur_p.size()); ++k) {
+ for (int k = 0; k < std::min(n_seq_dft + 3, (int) cur_p->size); ++k) {
LOG(" - draft candidate %3d for seq %3d, pos %3d: %6d (%8.3f) '%s'\n",
- k, s, i, cur_p[k].id, cur_p[k].p, llama_token_to_piece(ctx_dft, cur_p[k].id).c_str());
+ k, s, i, cur_p->data[k].id, cur_p->data[k].p, llama_token_to_piece(ctx_dft, cur_p->data[k].id).c_str());
}
std::vector<int> sa(1, s);
// attempt to split the branch if the probability is high enough
for (int f = 1; f < 8; ++f) {
- if (n_seq_cur < n_seq_dft && cur_p[f].p > p_split) {
+ if (n_seq_cur < n_seq_dft && cur_p->data[f].p > p_split) {
LOG("splitting seq %3d into %3d\n", s, n_seq_cur);
llama_kv_cache_seq_rm(ctx_dft, n_seq_cur, -1, -1);
drafts[n_seq_cur].i_batch_dft = drafts[s].i_batch_dft;
drafts[n_seq_cur].i_batch_tgt = drafts[s].i_batch_tgt;
- llama_sampling_cp(drafts[s].ctx_sampling, drafts[n_seq_cur].ctx_sampling);
+ if (drafts[n_seq_cur].smpl) {
+ gpt_sampler_free(drafts[n_seq_cur].smpl);
+ }
+ drafts[n_seq_cur].smpl = gpt_sampler_clone(drafts[s].smpl);
sa.push_back(n_seq_cur);
// add drafted token for each sequence
for (int is = 0; is < (int) sa.size(); ++is) {
- const llama_token id = cur_p[is].id;
+ const llama_token id = cur_p->data[is].id;
const int s = sa[is];
- llama_sampling_accept(drafts[s].ctx_sampling, ctx_dft, id, true);
+ gpt_sampler_accept(drafts[s].smpl, id, true);
drafts[s].tokens.push_back(id);
// save cur_p.data into drafts[s].dists
- drafts[s].dists.push_back(cur_p);
+ drafts[s].dists.push_back({cur_p->data, cur_p->data + cur_p->size});
// add unique drafted tokens to the target batch
drafts[s].i_batch_tgt.push_back(batch_tgt.n_tokens);
LOG_TEE("n_accept = %d\n", n_accept);
LOG_TEE("accept = %.3f%%\n", 100.0f * n_accept / n_drafted);
- LOG_TEE("\ndraft:\n");
- llama_print_timings(ctx_dft);
+ LOG_TEE("\ndraft:\n\n");
+ // TODO: print sampling/grammar timings for all drafts
+ llama_perf_print(ctx_dft, LLAMA_PERF_TYPE_CONTEXT);
- LOG_TEE("\ntarget:\n");
- llama_print_timings(ctx_tgt);
+ LOG_TEE("\ntarget:\n\n");
+ gpt_perf_print(ctx_tgt, smpl);
- llama_sampling_free(ctx_sampling);
+ gpt_sampler_free(smpl);
for (int s = 0; s < n_seq_dft; ++s) {
- llama_sampling_free(drafts[s].ctx_sampling);
+ gpt_sampler_free(drafts[s].smpl);
}
+ llama_sampler_free(softmax);
llama_batch_free(batch_dft);
llama_free(ctx_tgt);
#define LLAMA_DEFAULT_SEED 0xFFFFFFFF
+// TODO: use everywhere in the implementation
+#define LLAMA_TOKEN_NULL -1
+
#define LLAMA_FILE_MAGIC_GGLA 0x67676c61u // 'ggla'
#define LLAMA_FILE_MAGIC_GGSN 0x6767736eu // 'ggsn'
#define LLAMA_FILE_MAGIC_GGSQ 0x67677371u // 'ggsq'
#define LLAMA_SESSION_MAGIC LLAMA_FILE_MAGIC_GGSN
-#define LLAMA_SESSION_VERSION 8
+#define LLAMA_SESSION_VERSION 9
#define LLAMA_STATE_SEQ_MAGIC LLAMA_FILE_MAGIC_GGSQ
#define LLAMA_STATE_SEQ_VERSION 2
// TODO: show sample usage
//
+ // struct llama_vocab; // TODO: add in the future
struct llama_model;
struct llama_context;
+ struct llama_sampler;
typedef int32_t llama_pos;
typedef int32_t llama_token;
LLAMA_SPLIT_MODE_ROW = 2, // split rows across GPUs
};
+ // TODO: simplify (https://github.com/ggerganov/llama.cpp/pull/9294#pullrequestreview-2286561979)
typedef struct llama_token_data {
llama_token id; // token id
float logit; // log-odds of the token
} llama_token_data;
typedef struct llama_token_data_array {
+ // TODO: consider SoA
llama_token_data * data;
size_t size;
+ int64_t selected; // this is the index in the data array (i.e. not the token id)
bool sorted;
} llama_token_data_array;
// NOTE: changing the default values of parameters marked as [EXPERIMENTAL] may cause crashes or incorrect results in certain configurations
// https://github.com/ggerganov/llama.cpp/pull/7544
struct llama_context_params {
- uint32_t seed; // RNG seed, -1 for random
uint32_t n_ctx; // text context, 0 = from model
uint32_t n_batch; // logical maximum batch size that can be submitted to llama_decode
uint32_t n_ubatch; // physical maximum batch size
enum ggml_type type_k; // data type for K cache [EXPERIMENTAL]
enum ggml_type type_v; // data type for V cache [EXPERIMENTAL]
- // Keep the booleans together to avoid misalignment during copy-by-value.
+ // Keep the booleans together and at the end of the struct to avoid misalignment during copy-by-value.
+ // TODO: move to the end of the struct
bool logits_all; // the llama_decode() call computes all logits, not just the last one (DEPRECATED - set llama_batch.logits instead)
bool embeddings; // if true, extract embeddings (together with logits)
bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU
bool flash_attn; // whether to use flash attention [EXPERIMENTAL]
+ //bool no_perf; // whether to measure performance timings, TODO: implement
// Abort callback
// if it returns true, execution of llama_decode() will be aborted
void * kv_overrides; // pointer to vector containing overrides
} llama_model_quantize_params;
- // grammar types
- struct llama_grammar;
-
- // grammar element type
- enum llama_gretype {
- // end of rule definition
- LLAMA_GRETYPE_END = 0,
-
- // start of alternate definition for rule
- LLAMA_GRETYPE_ALT = 1,
-
- // non-terminal element: reference to rule
- LLAMA_GRETYPE_RULE_REF = 2,
-
- // terminal element: character (code point)
- LLAMA_GRETYPE_CHAR = 3,
-
- // inverse char(s) ([^a], [^a-b] [^abc])
- LLAMA_GRETYPE_CHAR_NOT = 4,
-
- // modifies a preceding LLAMA_GRETYPE_CHAR or LLAMA_GRETYPE_CHAR_ALT to
- // be an inclusive range ([a-z])
- LLAMA_GRETYPE_CHAR_RNG_UPPER = 5,
-
- // modifies a preceding LLAMA_GRETYPE_CHAR or
- // LLAMA_GRETYPE_CHAR_RNG_UPPER to add an alternate char to match ([ab], [a-zA])
- LLAMA_GRETYPE_CHAR_ALT = 6,
-
- // any character (.)
- LLAMA_GRETYPE_CHAR_ANY = 7,
- };
+ typedef struct llama_logit_bias {
+ llama_token token;
+ float bias;
+ } llama_logit_bias;
- typedef struct llama_grammar_element {
- enum llama_gretype type;
- uint32_t value; // Unicode code point or rule ID
- } llama_grammar_element;
-
- // performance timing information
- struct llama_timings {
- double t_start_ms;
- double t_end_ms;
- double t_load_ms;
- double t_sample_ms;
- double t_p_eval_ms;
- double t_eval_ms;
-
- int32_t n_sample;
- int32_t n_p_eval;
- int32_t n_eval;
- };
+ typedef struct llama_sampler_chain_params {
+ bool no_perf; // whether to measure performance timings
+ } llama_sampler_chain_params;
// used in chat template
typedef struct llama_chat_message {
struct llama_lora_adapter;
// Helpers for getting default parameters
- LLAMA_API struct llama_model_params llama_model_default_params(void);
- LLAMA_API struct llama_context_params llama_context_default_params(void);
+ // TODO: update API to start accepting pointers to params structs (https://github.com/ggerganov/llama.cpp/discussions/9172)
+ LLAMA_API struct llama_model_params llama_model_default_params(void);
+ LLAMA_API struct llama_context_params llama_context_default_params(void);
+ LLAMA_API struct llama_sampler_chain_params llama_sampler_chain_default_params(void);
LLAMA_API struct llama_model_quantize_params llama_model_quantize_default_params(void);
// Initialize the llama + ggml backend
LLAMA_API struct llama_model * llama_load_model_from_file(
const char * path_model,
- struct llama_model_params params);
+ struct llama_model_params params);
LLAMA_API void llama_free_model(struct llama_model * model);
+ // TODO: rename to llama_init_from_model
LLAMA_API struct llama_context * llama_new_context_with_model(
struct llama_model * model,
struct llama_context_params params);
LLAMA_API bool llama_supports_mlock (void);
LLAMA_API bool llama_supports_gpu_offload(void);
- LLAMA_API const struct llama_model * llama_get_model(const struct llama_context * ctx);
-
LLAMA_API uint32_t llama_n_ctx (const struct llama_context * ctx);
LLAMA_API uint32_t llama_n_batch (const struct llama_context * ctx);
LLAMA_API uint32_t llama_n_ubatch (const struct llama_context * ctx);
LLAMA_API uint32_t llama_n_seq_max (const struct llama_context * ctx);
- LLAMA_API enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx);
-
- LLAMA_API enum llama_vocab_type llama_vocab_type (const struct llama_model * model);
- LLAMA_API enum llama_rope_type llama_rope_type (const struct llama_model * model);
-
LLAMA_API int32_t llama_n_vocab (const struct llama_model * model);
LLAMA_API int32_t llama_n_ctx_train(const struct llama_model * model);
LLAMA_API int32_t llama_n_embd (const struct llama_model * model);
LLAMA_API int32_t llama_n_layer (const struct llama_model * model);
+ LLAMA_API const struct llama_model * llama_get_model(const struct llama_context * ctx);
+
+ LLAMA_API enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx);
+ LLAMA_API enum llama_vocab_type llama_vocab_type (const struct llama_model * model);
+ LLAMA_API enum llama_rope_type llama_rope_type (const struct llama_model * model);
+
// Get the model's RoPE frequency scaling factor
LLAMA_API float llama_rope_freq_scale_train(const struct llama_model * model);
//
// Returns the *actual* size in bytes of the state
- // (rng, logits, embedding and kv_cache)
+ // (logits, embedding and kv_cache)
// Only use when saving the state, not when restoring it, otherwise the size may be too small.
LLAMA_API size_t llama_state_get_size(struct llama_context * ctx);
LLAMA_API DEPRECATED(size_t llama_get_state_size(struct llama_context * ctx),
int32_t length);
//
- // Grammar
+ // Sampling API
+ //
+ // Sample usage:
+ //
+ // // prepare the sampling chain at the start
+ // auto sparams = llama_sampler_chain_default_params();
+ //
+ // llama_sampler * smpl = llama_sampler_chain_init(sparams);
+ //
+ // llama_sampler_chain_add(smpl, llama_sampler_init_top_k(50));
+ // llama_sampler_chain_add(smpl, llama_sampler_init_top_p(0.9, 1));
+ // llama_sampler_chain_add(smpl, llama_sampler_init_temp (0.8));
+ //
+ // // typically, the chain should end with a sampler such as "greedy", "dist" or "mirostat"
+ // // this sampler will be responsible for selecting the actual token
+ // llama_sampler_chain_add(smpl, llama_sampler_init_dist(seed));
+ //
+ // ...
+ //
+ // // decoding loop:
+ // while (...) {
+ // ...
+ //
+ // llama_decode(ctx, batch);
+ //
+ // // sample from the logits of the last token in the batch
+ // const llama_token id = llama_sampler_sample(smpl, ctx, -1);
+ //
+ // // accepting the token updates the internal state of certain samplers (e.g. grammar, repetition, etc.)
+ // llama_sampler_accept(smpl, id);
+ // ...
+ // }
+ //
+ // llama_sampler_free(smpl);
+ //
+ // TODO: In the future, llama_sampler will be utilized to offload the sampling to the backends (e.g. GPU).
+ // TODO: in the future, the entire sampling API that uses llama_model should start using llama_vocab
//
- /// Initialize a llama_grammar.
- ///
- /// @param rules The rule elements of the grammar to initialize.
- /// @param n_rules The number of rules.
- /// @param start_rule_index The index of the root rule (the starting point of the grammar).
- /// @return The initialized llama_grammar or nullptr if initialization failed.
- LLAMA_API struct llama_grammar * llama_grammar_init(
- const llama_grammar_element ** rules,
- size_t n_rules,
- size_t start_rule_index);
-
- LLAMA_API void llama_grammar_free(struct llama_grammar * grammar);
-
- LLAMA_API struct llama_grammar * llama_grammar_copy(const struct llama_grammar * grammar);
-
- /// @details Apply constraints from grammar
- LLAMA_API void llama_grammar_sample(
- const struct llama_grammar * grammar,
- const struct llama_context * ctx,
- llama_token_data_array * candidates);
- LLAMA_API DEPRECATED(void llama_sample_grammar(
- struct llama_context * ctx,
- llama_token_data_array * candidates,
- const struct llama_grammar * grammar),
- "use llama_grammar_sample instead");
+ typedef void * llama_sampler_context_t;
- /// @details Accepts the sampled token into the grammar
- LLAMA_API void llama_grammar_accept_token(
- struct llama_grammar * grammar,
- struct llama_context * ctx,
- llama_token token);
+ // user code can implement the interface below in order to create a custom llama_sampler
+ struct llama_sampler_i {
+ const char * (*name) (const struct llama_sampler * smpl); // can be NULL
+ void (*accept)( struct llama_sampler * smpl, llama_token token); // can be NULL
+ void (*apply) ( struct llama_sampler * smpl, llama_token_data_array * cur_p); // required
+ void (*reset) ( struct llama_sampler * smpl); // can be NULL
+ struct llama_sampler * (*clone) (const struct llama_sampler * smpl); // can be NULL if ctx is NULL
+ void (*free) ( struct llama_sampler * smpl); // can be NULL if ctx is NULL
- //
- // Sampling functions
- //
+ // TODO: API for internal libllama usage for appending the sampling to an existing ggml_cgraph
+ //void (*apply_ggml) (struct llama_sampler * smpl, ...);
+ };
- // Sets the current rng seed.
- LLAMA_API void llama_set_rng_seed(struct llama_context * ctx, uint32_t seed);
+ struct llama_sampler {
+ struct llama_sampler_i * iface;
+ llama_sampler_context_t ctx;
+ };
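As an illustration of the interface above, here is a minimal sketch of a custom sampler (the `my_*` names are hypothetical; only the `llama_*` types come from this header). It simply selects the highest-logit candidate, similar to the built-in greedy sampler:

```cpp
// hypothetical custom sampler: select the candidate with the highest logit
static const char * my_sampler_name(const struct llama_sampler * /*smpl*/) {
    return "my-greedy";
}

static void my_sampler_apply(struct llama_sampler * /*smpl*/, llama_token_data_array * cur_p) {
    size_t best = 0;
    for (size_t i = 1; i < cur_p->size; ++i) {
        if (cur_p->data[i].logit > cur_p->data[best].logit) {
            best = i;
        }
    }
    cur_p->selected = (int64_t) best; // index into cur_p->data, not a token id
}

static struct llama_sampler_i my_sampler_iface = {
    /* .name   = */ my_sampler_name,
    /* .accept = */ nullptr, // stateless, so accept/reset can be NULL
    /* .apply  = */ my_sampler_apply,
    /* .reset  = */ nullptr,
    /* .clone  = */ nullptr, // ctx is NULL, so clone/free may be NULL
    /* .free   = */ nullptr,
};

static struct llama_sampler my_sampler = {
    /* .iface = */ &my_sampler_iface,
    /* .ctx   = */ nullptr,
};
```

Such a sampler can be applied to a prepared `llama_token_data_array` with `llama_sampler_apply`. Note that `llama_sampler_chain_add` takes ownership of the samplers it receives, so a static instance like this one should not be added to a chain.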
- /// @details Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix.
- /// @details Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details.
- LLAMA_API void llama_sample_repetition_penalties(
- struct llama_context * ctx,
- llama_token_data_array * candidates,
- const llama_token * last_tokens,
- size_t penalty_last_n,
- float penalty_repeat,
- float penalty_freq,
- float penalty_present);
-
- /// @details Apply classifier-free guidance to the logits as described in academic paper "Stay on topic with Classifier-Free Guidance" https://arxiv.org/abs/2306.17806
- /// @param logits Logits extracted from the original generation context.
- /// @param logits_guidance Logits extracted from a separate context from the same model. Other than a negative prompt at the beginning, it should have all generated and user input tokens copied from the main context.
- /// @param scale Guidance strength. 1.0f means no guidance. Higher values mean stronger guidance.
- LLAMA_API void llama_sample_apply_guidance(
- struct llama_context * ctx,
- float * logits,
- float * logits_guidance,
- float scale);
+ // mirror of llama_sampler_i:
+ LLAMA_API const char * llama_sampler_name (const struct llama_sampler * smpl);
+ LLAMA_API void llama_sampler_accept( struct llama_sampler * smpl, llama_token token);
+ LLAMA_API void llama_sampler_apply ( struct llama_sampler * smpl, llama_token_data_array * cur_p);
+ LLAMA_API void llama_sampler_reset ( struct llama_sampler * smpl);
+ LLAMA_API struct llama_sampler * llama_sampler_clone (const struct llama_sampler * smpl);
+ // important: do not free if the sampler has been added to a llama_sampler_chain (via llama_sampler_chain_add)
+ LLAMA_API void llama_sampler_free ( struct llama_sampler * smpl);
+
+ // llama_sampler_chain
+ // a type of llama_sampler that can chain multiple samplers one after another
+
+ LLAMA_API struct llama_sampler * llama_sampler_chain_init(struct llama_sampler_chain_params params);
+
+ // important: takes ownership of the sampler object and will free it when llama_sampler_free is called
+ LLAMA_API void llama_sampler_chain_add( struct llama_sampler * chain, struct llama_sampler * smpl);
+ LLAMA_API struct llama_sampler * llama_sampler_chain_get(const struct llama_sampler * chain, int32_t i);
+ LLAMA_API int llama_sampler_chain_n (const struct llama_sampler * chain);
+
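For example (a sketch, assuming `chain` was created with `llama_sampler_chain_init` and populated via `llama_sampler_chain_add`), the accessors above can be used to inspect a chain:

```cpp
// print the samplers currently in a chain (requires <cstdio>)
for (int i = 0; i < llama_sampler_chain_n(chain); ++i) {
    const struct llama_sampler * s = llama_sampler_chain_get(chain, i);
    printf("sampler %d: %s\n", i, llama_sampler_name(s));
}
```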
+ // available samplers:
+
+ LLAMA_API struct llama_sampler * llama_sampler_init_greedy (void);
+ LLAMA_API struct llama_sampler * llama_sampler_init_dist (uint32_t seed);
/// @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits.
- LLAMA_API void llama_sample_softmax(
- struct llama_context * ctx,
- llama_token_data_array * candidates);
+ LLAMA_API struct llama_sampler * llama_sampler_init_softmax (void);
/// @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
- LLAMA_API void llama_sample_top_k(
- struct llama_context * ctx,
- llama_token_data_array * candidates,
- int32_t k,
- size_t min_keep);
+ LLAMA_API struct llama_sampler * llama_sampler_init_top_k (int32_t k);
/// @details Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
- LLAMA_API void llama_sample_top_p(
- struct llama_context * ctx,
- llama_token_data_array * candidates,
- float p,
- size_t min_keep);
+ LLAMA_API struct llama_sampler * llama_sampler_init_top_p (float p, size_t min_keep);
/// @details Minimum P sampling as described in https://github.com/ggerganov/llama.cpp/pull/3841
- LLAMA_API void llama_sample_min_p(
- struct llama_context * ctx,
- llama_token_data_array * candidates,
- float p,
- size_t min_keep);
+ LLAMA_API struct llama_sampler * llama_sampler_init_min_p (float p, size_t min_keep);
/// @details Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/.
- LLAMA_API void llama_sample_tail_free(
- struct llama_context * ctx,
- llama_token_data_array * candidates,
- float z,
- size_t min_keep);
+ LLAMA_API struct llama_sampler * llama_sampler_init_tail_free (float z, size_t min_keep);
/// @details Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666.
- LLAMA_API void llama_sample_typical(
- struct llama_context * ctx,
- llama_token_data_array * candidates,
- float p,
- size_t min_keep);
+ LLAMA_API struct llama_sampler * llama_sampler_init_typical (float p, size_t min_keep);
+ LLAMA_API struct llama_sampler * llama_sampler_init_temp (float t);
- /// @details Dynamic temperature implementation described in the paper https://arxiv.org/abs/2309.02772.
- LLAMA_API void llama_sample_entropy(
- struct llama_context * ctx,
- llama_token_data_array * candidates_p,
- float min_temp,
- float max_temp,
- float exponent_val);
-
- LLAMA_API void llama_sample_temp(
- struct llama_context * ctx,
- llama_token_data_array * candidates,
- float temp);
+ /// @details Dynamic temperature implementation (a.k.a. entropy) described in the paper https://arxiv.org/abs/2309.02772.
+ LLAMA_API struct llama_sampler * llama_sampler_init_temp_ext (float t, float delta, float exponent);
/// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
/// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
/// @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.
/// @param m The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`. In the paper, they use `m = 100`, but you can experiment with different values to see how it affects the performance of the algorithm.
/// @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.
- LLAMA_API llama_token llama_sample_token_mirostat(
- struct llama_context * ctx,
- llama_token_data_array * candidates,
- float tau,
- float eta,
- int32_t m,
- float * mu);
+ LLAMA_API struct llama_sampler * llama_sampler_init_mirostat(
+ int32_t n_vocab,
+ uint32_t seed,
+ float tau,
+ float eta,
+ int32_t m);
/// @details Mirostat 2.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
/// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
/// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
/// @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.
/// @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.
- LLAMA_API llama_token llama_sample_token_mirostat_v2(
- struct llama_context * ctx,
- llama_token_data_array * candidates,
- float tau,
- float eta,
- float * mu);
-
- /// @details Selects the token with the highest probability.
- /// Does not compute the token probabilities. Use llama_sample_softmax() instead.
- LLAMA_API llama_token llama_sample_token_greedy(
- struct llama_context * ctx,
- llama_token_data_array * candidates);
+ LLAMA_API struct llama_sampler * llama_sampler_init_mirostat_v2(
+ uint32_t seed,
+ float tau,
+ float eta);
+
+ LLAMA_API struct llama_sampler * llama_sampler_init_grammar(
+ const struct llama_model * model,
+ const char * grammar_str,
+ const char * grammar_root);
+
+ LLAMA_API struct llama_sampler * llama_sampler_init_penalties(
+ int32_t n_vocab, // llama_n_vocab()
+ llama_token special_eos_id, // llama_token_eos()
+ llama_token linefeed_id, // llama_token_nl()
+ int32_t penalty_last_n, // last n tokens to penalize (0 = disable penalty, -1 = context size)
+ float penalty_repeat, // 1.0 = disabled
+ float penalty_freq, // 0.0 = disabled
+ float penalty_present, // 0.0 = disabled
+ bool penalize_nl, // consider newlines as a repeatable token
+ bool ignore_eos); // ignore the end-of-sequence token
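A usage sketch (the numeric values are illustrative, not defaults mandated by the API); the penalties sampler is typically added to a chain before the final token-selecting sampler:

```cpp
// illustrative: penalize repetition over the last 64 tokens, other penalties disabled
llama_sampler_chain_add(smpl, llama_sampler_init_penalties(
    llama_n_vocab(model),
    llama_token_eos(model),
    llama_token_nl(model),
    64,     // penalty_last_n
    1.1f,   // penalty_repeat
    0.0f,   // penalty_freq    (disabled)
    0.0f,   // penalty_present (disabled)
    false,  // penalize_nl
    false   // ignore_eos
));
```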
+
+ LLAMA_API struct llama_sampler * llama_sampler_init_logit_bias(
+ int32_t n_vocab,
+ int32_t n_logit_bias,
+ const llama_logit_bias * logit_bias);
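Similarly, a sketch of the logit-bias sampler (the token id `42` is a made-up example; biasing the EOS token to `-INFINITY` mirrors how the server hunk above implements `ignore_eos`):

```cpp
// illustrative: boost one token and never sample EOS (INFINITY from <cmath>)
const llama_logit_bias biases[] = {
    { 42,                     2.0f      }, // hypothetical token id
    { llama_token_eos(model), -INFINITY },
};

llama_sampler_chain_add(smpl, llama_sampler_init_logit_bias(
    llama_n_vocab(model), 2, biases));
```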
+
+ // Shorthand for:
+ //
+ // const auto * logits = llama_get_logits_ith(ctx, idx);
+ // llama_token_data_array cur_p = { ... init from logits ... };
+ // llama_sampler_apply(smpl, &cur_p);
+ // return cur_p.data[cur_p.selected].id;
+ //
+ // At this point, this is mostly a convenience function.
+ //
+ LLAMA_API llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_context * ctx, int32_t idx);
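Spelled out, the shorthand above corresponds roughly to the following sketch (error handling omitted; `smpl`, `ctx` and `idx` as in the declaration):

```cpp
// rough manual equivalent of: const llama_token id = llama_sampler_sample(smpl, ctx, idx);
const float * logits  = llama_get_logits_ith(ctx, idx);
const int     n_vocab = llama_n_vocab(llama_get_model(ctx));

std::vector<llama_token_data> cur;
cur.reserve(n_vocab);
for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
    cur.push_back({ token_id, logits[token_id], 0.0f });
}

llama_token_data_array cur_p = { cur.data(), cur.size(), /* selected */ -1, /* sorted */ false };

llama_sampler_apply(smpl, &cur_p);

const llama_token id = cur_p.data[cur_p.selected].id;
```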
- /// @details Randomly selects a token from the candidates based on their probabilities using the RNG of ctx.
- LLAMA_API llama_token llama_sample_token(
- struct llama_context * ctx,
- llama_token_data_array * candidates);
+ // TODO: extend in the future
+ //LLAMA_API void llama_decode_with_sampler(struct llama_context * ctx, struct llama_sampler * smpl, struct llama_batch batch, ...);
//
// Model split
// Returns the split_prefix length.
LLAMA_API int llama_split_prefix(char * split_prefix, size_t maxlen, const char * split_path, int split_no, int split_count);
- // Performance information
- LLAMA_API struct llama_timings llama_get_timings(struct llama_context * ctx);
-
- LLAMA_API void llama_print_timings(struct llama_context * ctx);
- LLAMA_API void llama_reset_timings(struct llama_context * ctx);
-
// Print system information
LLAMA_API const char * llama_print_system_info(void);
// If this is not called, or NULL is supplied, everything is output on stderr.
LLAMA_API void llama_log_set(ggml_log_callback log_callback, void * user_data);
- LLAMA_API void llama_dump_timing_info_yaml(FILE * stream, const struct llama_context * ctx);
-
-#ifdef __cplusplus
-}
-#endif
-
-// Internal API to be implemented by llama.cpp and used by tests/benchmarks only
-#ifdef LLAMA_API_INTERNAL
-
-#include <random>
-#include <string>
-#include <vector>
-
-struct ggml_tensor;
-
-const std::vector<std::pair<std::string, struct ggml_tensor *>> & llama_internal_get_tensor_map(
- struct llama_context * ctx
-);
-
-struct llama_partial_utf8 {
- uint32_t value; // bit value so far (unshifted)
- int n_remain; // num bytes remaining; -1 indicates invalid sequence
-};
-
-struct llama_grammar_candidate {
- size_t index;
- const uint32_t * code_points;
- llama_partial_utf8 partial_utf8;
-};
-
-using llama_grammar_rule = std::vector< llama_grammar_element>;
-using llama_grammar_stack = std::vector<const llama_grammar_element *>;
-
-using llama_grammar_rules = std::vector<llama_grammar_rule>;
-using llama_grammar_stacks = std::vector<llama_grammar_stack>;
-using llama_grammar_candidates = std::vector<llama_grammar_candidate>;
-
-const llama_grammar_rules & llama_grammar_get_rules (const struct llama_grammar * grammar);
- llama_grammar_stacks & llama_grammar_get_stacks( struct llama_grammar * grammar);
-
-void llama_grammar_accept(
- const llama_grammar_rules & rules,
- const llama_grammar_stacks & stacks,
- const uint32_t chr,
- llama_grammar_stacks & new_stacks);
+ //
+ // Performance utils
+ //
+ // NOTE: Used by llama.cpp examples, avoid using in third-party apps. Instead, do your own performance measurements.
+ //
-std::vector<llama_grammar_candidate> llama_grammar_reject_candidates_for_stack(
- const llama_grammar_rules & rules,
- const llama_grammar_stack & stack,
- const llama_grammar_candidates & candidates);
+ enum llama_perf_type {
+ LLAMA_PERF_TYPE_CONTEXT = 0,
+ LLAMA_PERF_TYPE_SAMPLER_CHAIN = 1,
+ };
-std::pair<std::vector<uint32_t>, llama_partial_utf8> decode_utf8(
- const std::string & src,
- llama_partial_utf8 partial_start);
+ LLAMA_API void llama_perf_print(const void * ctx, enum llama_perf_type type);
+ LLAMA_API void llama_perf_reset( void * ctx, enum llama_perf_type type);
-// Randomly selects a token from the candidates based on their probabilities using given std::mt19937.
-// This is a temporary workaround in order to fix race conditions when sampling with multiple sequences.
-llama_token llama_sample_token_with_rng(struct llama_context * ctx, llama_token_data_array * candidates, std::mt19937 & rng);
+ LLAMA_API void llama_perf_dump_yaml(FILE * stream, const struct llama_context * ctx);
-#endif // LLAMA_API_INTERNAL
+#ifdef __cplusplus
+}
+#endif
#endif // LLAMA_H
#include "llama-vocab.h"
#include "llama-sampling.h"
+#include <cmath>
#include <algorithm>
+#include <stdexcept>
-// Decodes a UTF-8 string which may end in an incomplete sequence. Adds a terminating 0 for use as
-// pointer. If an invalid sequence is encountered, returns `llama_partial_utf8.n_remain == -1`.
-std::pair<std::vector<uint32_t>, llama_partial_utf8> decode_utf8(
+//
+// helpers
+//
+
+// NOTE: assumes valid utf8 (but checks for overrun)
+static std::pair<uint32_t, const char *> decode_utf8(const char * src) {
+ static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 };
+ uint8_t first_byte = static_cast<uint8_t>(*src);
+ uint8_t highbits = first_byte >> 4;
+ int len = lookup[highbits];
+ uint8_t mask = (1 << (8 - len)) - 1;
+ uint32_t value = first_byte & mask;
+ const char * end = src + len; // may overrun!
+ const char * pos = src + 1;
+ for ( ; pos < end && *pos; pos++) {
+ value = (value << 6) + (static_cast<uint8_t>(*pos) & 0x3F);
+ }
+ return std::make_pair(value, pos);
+}
+
+static std::pair<std::vector<uint32_t>, llama_partial_utf8> decode_utf8(
const std::string & src,
llama_partial_utf8 partial_start) {
static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 2, 2, 3, 4 };
while (*pos != 0) {
uint8_t first_byte = static_cast<uint8_t>(*pos);
uint8_t highbits = first_byte >> 4;
- n_remain = lookup[highbits] - 1;
+ n_remain = lookup[highbits] - 1;
if (n_remain < 0) {
// invalid sequence, abort
}
uint8_t mask = (1 << (7 - n_remain)) - 1;
- value = first_byte & mask;
+ value = first_byte & mask;
++pos;
while (*pos != 0 && n_remain > 0) {
return std::make_pair(std::move(code_points), llama_partial_utf8{ value, n_remain });
}
-const llama_grammar_rules & llama_grammar_get_rules(const struct llama_grammar * grammar) {
- return grammar->rules;
+static bool is_digit_char(char c) {
+ return '0' <= c && c <= '9';
}
-llama_grammar_stacks & llama_grammar_get_stacks(struct llama_grammar * grammar) {
- return grammar->stacks;
+static bool is_word_char(char c) {
+ return ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || c == '-' || is_digit_char(c);
+}
+
+static std::pair<uint32_t, const char *> parse_hex(const char * src, int size) {
+ const char * pos = src;
+ const char * end = src + size;
+ uint32_t value = 0;
+ for ( ; pos < end && *pos; pos++) {
+ value <<= 4;
+ char c = *pos;
+ if ('a' <= c && c <= 'f') {
+ value += c - 'a' + 10;
+ } else if ('A' <= c && c <= 'F') {
+ value += c - 'A' + 10;
+ } else if ('0' <= c && c <= '9') {
+ value += c - '0';
+ } else {
+ break;
+ }
+ }
+ if (pos != end) {
+ throw std::runtime_error("expecting " + std::to_string(size) + " hex chars at " + src);
+ }
+ return std::make_pair(value, pos);
+}
+
+static const char * parse_space(const char * src, bool newline_ok) {
+ const char * pos = src;
+ while (*pos == ' ' || *pos == '\t' || *pos == '#' ||
+ (newline_ok && (*pos == '\r' || *pos == '\n'))) {
+ if (*pos == '#') {
+ while (*pos && *pos != '\r' && *pos != '\n') {
+ pos++;
+ }
+ } else {
+ pos++;
+ }
+ }
+ return pos;
+}
+
+static const char * parse_name(const char * src) {
+ const char * pos = src;
+ while (is_word_char(*pos)) {
+ pos++;
+ }
+ if (pos == src) {
+ throw std::runtime_error(std::string("expecting name at ") + src);
+ }
+ return pos;
+}
+
+static const char * parse_int(const char * src) {
+ const char * pos = src;
+ while (is_digit_char(*pos)) {
+ pos++;
+ }
+ if (pos == src) {
+ throw std::runtime_error(std::string("expecting integer at ") + src);
+ }
+ return pos;
+}
+
+static std::pair<uint32_t, const char *> parse_char(const char * src) {
+ if (*src == '\\') {
+ switch (src[1]) {
+ case 'x': return parse_hex(src + 2, 2);
+ case 'u': return parse_hex(src + 2, 4);
+ case 'U': return parse_hex(src + 2, 8);
+ case 't': return std::make_pair('\t', src + 2);
+ case 'r': return std::make_pair('\r', src + 2);
+ case 'n': return std::make_pair('\n', src + 2);
+ case '\\':
+ case '"':
+ case '[':
+ case ']':
+ return std::make_pair(src[1], src + 2);
+ default:
+ throw std::runtime_error(std::string("unknown escape at ") + src);
+ }
+ } else if (*src) {
+ return decode_utf8(src);
+ }
+ throw std::runtime_error("unexpected end of input");
+}
+
+static void print_grammar_char(FILE * file, uint32_t c) {
+ if (0x20 <= c && c <= 0x7f) {
+ fprintf(file, "%c", static_cast<char>(c));
+ } else {
+ // cop out of encoding UTF-8
+ fprintf(file, "<U+%04X>", c);
+ }
+}
+
+static bool is_char_element(llama_grammar_element elem) {
+ switch (elem.type) {
+ case LLAMA_GRETYPE_CHAR: return true;
+ case LLAMA_GRETYPE_CHAR_NOT: return true;
+ case LLAMA_GRETYPE_CHAR_ALT: return true;
+ case LLAMA_GRETYPE_CHAR_RNG_UPPER: return true;
+ case LLAMA_GRETYPE_CHAR_ANY: return true;
+ default: return false;
+ }
+}
+
+static void print_rule_binary(FILE * file, const llama_grammar_rule & rule) {
+ for (auto elem : rule) {
+ switch (elem.type) {
+ case LLAMA_GRETYPE_END: fprintf(file, "END"); break;
+ case LLAMA_GRETYPE_ALT: fprintf(file, "ALT"); break;
+ case LLAMA_GRETYPE_RULE_REF: fprintf(file, "RULE_REF"); break;
+ case LLAMA_GRETYPE_CHAR: fprintf(file, "CHAR"); break;
+ case LLAMA_GRETYPE_CHAR_NOT: fprintf(file, "CHAR_NOT"); break;
+ case LLAMA_GRETYPE_CHAR_RNG_UPPER: fprintf(file, "CHAR_RNG_UPPER"); break;
+ case LLAMA_GRETYPE_CHAR_ALT: fprintf(file, "CHAR_ALT"); break;
+ case LLAMA_GRETYPE_CHAR_ANY: fprintf(file, "CHAR_ANY"); break;
+ }
+ switch (elem.type) {
+ case LLAMA_GRETYPE_END:
+ case LLAMA_GRETYPE_ALT:
+ case LLAMA_GRETYPE_RULE_REF:
+ fprintf(file, "(%u) ", elem.value);
+ break;
+ case LLAMA_GRETYPE_CHAR:
+ case LLAMA_GRETYPE_CHAR_NOT:
+ case LLAMA_GRETYPE_CHAR_RNG_UPPER:
+ case LLAMA_GRETYPE_CHAR_ALT:
+ case LLAMA_GRETYPE_CHAR_ANY:
+ fprintf(file, "(\"");
+ print_grammar_char(file, elem.value);
+ fprintf(file, "\") ");
+ break;
+ }
+ }
+ fprintf(file, "\n");
+}
+
+static void print_rule(
+ FILE * file,
+ uint32_t rule_id,
+ const llama_grammar_rule & rule,
+ const std::map<uint32_t, std::string> & symbol_id_names) {
+ if (rule.empty() || rule.back().type != LLAMA_GRETYPE_END) {
+ throw std::runtime_error(
+ "malformed rule, does not end with LLAMA_GRETYPE_END: " + std::to_string(rule_id));
+ }
+ fprintf(file, "%s ::= ", symbol_id_names.at(rule_id).c_str());
+ for (size_t i = 0, end = rule.size() - 1; i < end; i++) {
+ llama_grammar_element elem = rule[i];
+ switch (elem.type) {
+ case LLAMA_GRETYPE_END:
+ throw std::runtime_error(
+ "unexpected end of rule: " + std::to_string(rule_id) + "," +
+ std::to_string(i));
+ case LLAMA_GRETYPE_ALT:
+ fprintf(file, "| ");
+ break;
+ case LLAMA_GRETYPE_RULE_REF:
+ fprintf(file, "%s ", symbol_id_names.at(elem.value).c_str());
+ break;
+ case LLAMA_GRETYPE_CHAR:
+ fprintf(file, "[");
+ print_grammar_char(file, elem.value);
+ break;
+ case LLAMA_GRETYPE_CHAR_NOT:
+ fprintf(file, "[^");
+ print_grammar_char(file, elem.value);
+ break;
+ case LLAMA_GRETYPE_CHAR_RNG_UPPER:
+ if (i == 0 || !is_char_element(rule[i - 1])) {
+ throw std::runtime_error(
+ "LLAMA_GRETYPE_CHAR_RNG_UPPER without preceding char: " +
+ std::to_string(rule_id) + "," + std::to_string(i));
+ }
+ fprintf(file, "-");
+ print_grammar_char(file, elem.value);
+ break;
+ case LLAMA_GRETYPE_CHAR_ALT:
+ if (i == 0 || !is_char_element(rule[i - 1])) {
+ throw std::runtime_error(
+ "LLAMA_GRETYPE_CHAR_ALT without preceding char: " +
+ std::to_string(rule_id) + "," + std::to_string(i));
+ }
+ print_grammar_char(file, elem.value);
+ break;
+ case LLAMA_GRETYPE_CHAR_ANY:
+ fprintf(file, ".");
+ break;
+ }
+ if (is_char_element(elem)) {
+ switch (rule[i + 1].type) {
+ case LLAMA_GRETYPE_CHAR_ALT:
+ case LLAMA_GRETYPE_CHAR_RNG_UPPER:
+ case LLAMA_GRETYPE_CHAR_ANY:
+ break;
+ default:
+ fprintf(file, "] ");
+ }
+ }
+ }
+ fprintf(file, "\n");
+}
+
+//
+// implementation
+//
+
+uint32_t llama_grammar_parser::get_symbol_id(const char * src, size_t len) {
+ uint32_t next_id = static_cast<uint32_t>(symbol_ids.size());
+ auto result = symbol_ids.emplace(std::string(src, len), next_id);
+ return result.first->second;
+}
+
+uint32_t llama_grammar_parser::generate_symbol_id(const std::string & base_name) {
+ uint32_t next_id = static_cast<uint32_t>(symbol_ids.size());
+ symbol_ids[base_name + '_' + std::to_string(next_id)] = next_id;
+ return next_id;
+}
+
+void llama_grammar_parser::add_rule(uint32_t rule_id, const llama_grammar_rule & rule) {
+ if (rules.size() <= rule_id) {
+ rules.resize(rule_id + 1);
+ }
+ rules[rule_id] = rule;
+}
+
+const char * llama_grammar_parser::parse_alternates(
+ const char * src,
+ const std::string & rule_name,
+ uint32_t rule_id,
+ bool is_nested) {
+ llama_grammar_rule rule;
+ const char * pos = parse_sequence(src, rule_name, rule, is_nested);
+ while (*pos == '|') {
+ rule.push_back({LLAMA_GRETYPE_ALT, 0});
+ pos = parse_space(pos + 1, true);
+ pos = parse_sequence(pos, rule_name, rule, is_nested);
+ }
+ rule.push_back({LLAMA_GRETYPE_END, 0});
+ add_rule(rule_id, rule);
+ return pos;
+}
+
+const char * llama_grammar_parser::parse_sequence(
+ const char * src,
+ const std::string & rule_name,
+ llama_grammar_rule & rule,
+ bool is_nested) {
+ size_t last_sym_start = rule.size();
+ const char * pos = src;
+
+ auto handle_repetitions = [&](int min_times, int max_times) {
+
+ if (last_sym_start == rule.size()) {
+ throw std::runtime_error(std::string("expecting preceding item to */+/?/{ at ") + pos);
+ }
+
+ // apply transformation to previous symbol (last_sym_start to end) according to
+ // the following rewrite rules:
+ // S{m,n} --> S S S (m times) S'(n-m)
+ // S'(x) ::= S S'(x-1) |
+ // (... n-m definitions of these S' rules ...)
+ // S'(1) ::= S |
+ // S{m,} --> S S S (m times) S'
+ // S' ::= S S' |
+ // S* --> S{0,}
+ // --> S' ::= S S' |
+ // S+ --> S{1,}
+ // --> S S'
+ // S' ::= S S' |
+ // S? --> S{0,1}
+ // --> S'
+ // S' ::= S |
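+        //
+        // worked example (for illustration): if a rule R ends with "a"{2,3} then
+        //   min_times = 2, max_times = 3, so one extra copy of "a" is emitted inline and a
+        //   single synthesized rule covers the one optional repetition:
+        //     R    ::= ... "a" "a" R_4
+        //     R_4  ::= "a" |
+        //   where "_4" stands for the numeric id returned by generate_symbol_id(rule_name)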
+
+ llama_grammar_rule prev_rule(rule.begin() + last_sym_start, rule.end());
+ if (min_times == 0) {
+ rule.resize(last_sym_start);
+ } else {
+ // Repeat the previous elements (min_times - 1) times
+ for (int i = 1; i < min_times; i++) {
+ rule.insert(rule.end(), prev_rule.begin(), prev_rule.end());
+ }
+ }
+
+ uint32_t last_rec_rule_id = 0;
+ auto n_opt = max_times < 0 ? 1 : max_times - min_times;
+
+ llama_grammar_rule rec_rule(prev_rule);
+ for (int i = 0; i < n_opt; i++) {
+ rec_rule.resize(prev_rule.size());
+ uint32_t rec_rule_id = generate_symbol_id( rule_name);
+ if (i > 0 || max_times < 0) {
+ rec_rule.push_back({LLAMA_GRETYPE_RULE_REF, max_times < 0 ? rec_rule_id : last_rec_rule_id});
+ }
+ rec_rule.push_back({LLAMA_GRETYPE_ALT, 0});
+ rec_rule.push_back({LLAMA_GRETYPE_END, 0});
+ add_rule( rec_rule_id, rec_rule);
+ last_rec_rule_id = rec_rule_id;
+ }
+ if (n_opt > 0) {
+ rule.push_back({LLAMA_GRETYPE_RULE_REF, last_rec_rule_id});
+ }
+ };
+
+ while (*pos) {
+ if (*pos == '"') { // literal string
+ pos++;
+ last_sym_start = rule.size();
+ while (*pos != '"') {
+ if (!*pos) {
+ throw std::runtime_error("unexpected end of input");
+ }
+ auto char_pair = parse_char(pos);
+ pos = char_pair.second;
+ rule.push_back({LLAMA_GRETYPE_CHAR, char_pair.first});
+ }
+ pos = parse_space(pos + 1, is_nested);
+ } else if (*pos == '[') { // char range(s)
+ pos++;
+ enum llama_gretype start_type = LLAMA_GRETYPE_CHAR;
+ if (*pos == '^') {
+ pos++;
+ start_type = LLAMA_GRETYPE_CHAR_NOT;
+ }
+ last_sym_start = rule.size();
+ while (*pos != ']') {
+ if (!*pos) {
+ throw std::runtime_error("unexpected end of input");
+ }
+ auto char_pair = parse_char(pos);
+ pos = char_pair.second;
+ enum llama_gretype type = last_sym_start < rule.size()
+ ? LLAMA_GRETYPE_CHAR_ALT
+ : start_type;
+
+ rule.push_back({type, char_pair.first});
+ if (pos[0] == '-' && pos[1] != ']') {
+ if (!pos[1]) {
+ throw std::runtime_error("unexpected end of input");
+ }
+ auto endchar_pair = parse_char(pos + 1);
+ pos = endchar_pair.second;
+ rule.push_back({LLAMA_GRETYPE_CHAR_RNG_UPPER, endchar_pair.first});
+ }
+ }
+ pos = parse_space(pos + 1, is_nested);
+ } else if (is_word_char(*pos)) { // rule reference
+ const char * name_end = parse_name(pos);
+ uint32_t ref_rule_id = get_symbol_id(pos, name_end - pos);
+ pos = parse_space(name_end, is_nested);
+ last_sym_start = rule.size();
+ rule.push_back({LLAMA_GRETYPE_RULE_REF, ref_rule_id});
+ } else if (*pos == '(') { // grouping
+ // parse nested alternates into synthesized rule
+ pos = parse_space(pos + 1, true);
+ uint32_t sub_rule_id = generate_symbol_id(rule_name);
+ pos = parse_alternates(pos, rule_name, sub_rule_id, true);
+ last_sym_start = rule.size();
+ // output reference to synthesized rule
+ rule.push_back({LLAMA_GRETYPE_RULE_REF, sub_rule_id});
+ if (*pos != ')') {
+ throw std::runtime_error(std::string("expecting ')' at ") + pos);
+ }
+ pos = parse_space(pos + 1, is_nested);
+ } else if (*pos == '.') { // any char
+ last_sym_start = rule.size();
+ rule.push_back({LLAMA_GRETYPE_CHAR_ANY, 0});
+ pos = parse_space(pos + 1, is_nested);
+ } else if (*pos == '*') {
+ pos = parse_space(pos + 1, is_nested);
+ handle_repetitions(0, -1);
+ } else if (*pos == '+') {
+ pos = parse_space(pos + 1, is_nested);
+ handle_repetitions(1, -1);
+ } else if (*pos == '?') {
+ pos = parse_space(pos + 1, is_nested);
+ handle_repetitions(0, 1);
+ } else if (*pos == '{') {
+ pos = parse_space(pos + 1, is_nested);
+
+ if (!is_digit_char(*pos)) {
+ throw std::runtime_error(std::string("expecting an int at ") + pos);
+ }
+ const char * int_end = parse_int(pos);
+ int min_times = std::stoul(std::string(pos, int_end - pos));
+ pos = parse_space(int_end, is_nested);
+
+ int max_times = -1;
+
+ if (*pos == '}') {
+ max_times = min_times;
+ pos = parse_space(pos + 1, is_nested);
+ } else if (*pos == ',') {
+ pos = parse_space(pos + 1, is_nested);
+
+ if (is_digit_char(*pos)) {
+ const char * int_end = parse_int(pos);
+ max_times = std::stoul(std::string(pos, int_end - pos));
+ pos = parse_space(int_end, is_nested);
+ }
+
+ if (*pos != '}') {
+ throw std::runtime_error(std::string("expecting '}' at ") + pos);
+ }
+ pos = parse_space(pos + 1, is_nested);
+ } else {
+ throw std::runtime_error(std::string("expecting ',' at ") + pos);
+ }
+ handle_repetitions(min_times, max_times);
+ } else {
+ break;
+ }
+ }
+ return pos;
+}
+
+const char * llama_grammar_parser::parse_rule(const char * src) {
+ const char * name_end = parse_name(src);
+ const char * pos = parse_space(name_end, false);
+ size_t name_len = name_end - src;
+ uint32_t rule_id = get_symbol_id(src, name_len);
+ const std::string name(src, name_len);
+
+ if (!(pos[0] == ':' && pos[1] == ':' && pos[2] == '=')) {
+ throw std::runtime_error(std::string("expecting ::= at ") + pos);
+ }
+ pos = parse_space(pos + 3, true);
+
+ pos = parse_alternates(pos, name, rule_id, false);
+
+ if (*pos == '\r') {
+ pos += pos[1] == '\n' ? 2 : 1;
+ } else if (*pos == '\n') {
+ pos++;
+ } else if (*pos) {
+ throw std::runtime_error(std::string("expecting newline or end at ") + pos);
+ }
+ return parse_space(pos, true);
+}
+
+bool llama_grammar_parser::parse(const char * src) {
+ try {
+ const char * pos = parse_space(src, true);
+ while (*pos) {
+ pos = parse_rule(pos);
+ }
+ // Validate the state to ensure that all rules are defined
+ for (const auto & rule : rules) {
+ if (rule.empty()) {
+ throw std::runtime_error("Undefined rule");
+ }
+ for (const auto & elem : rule) {
+ if (elem.type == LLAMA_GRETYPE_RULE_REF) {
+ // Ensure that the rule at that location exists
+ if (elem.value >= rules.size() || rules[elem.value].empty()) {
+ // Get the name of the rule that is missing
+ for (const auto & kv : symbol_ids) {
+ if (kv.second == elem.value) {
+ throw std::runtime_error("Undefined rule identifier '" + kv.first + "'");
+ }
+ }
+ }
+ }
+ }
+ }
+ } catch (const std::exception & err) {
+ fprintf(stderr, "%s: error parsing grammar: %s\n", __func__, err.what());
+ rules.clear();
+ return false;
+ }
+
+ return true;
+}
+
+void llama_grammar_parser::print(FILE * file) {
+ try {
+ std::map<uint32_t, std::string> symbol_id_names;
+ for (const auto & kv : symbol_ids) {
+ symbol_id_names[kv.second] = kv.first;
+ }
+ for (size_t i = 0, end = rules.size(); i < end; i++) {
+ // fprintf(file, "%zu: ", i);
+ // print_rule_binary(file, rules[i]);
+ print_rule(file, uint32_t(i), rules[i], symbol_id_names);
+ // fprintf(file, "\n");
+ }
+ } catch (const std::exception & err) {
+ fprintf(stderr, "\n%s: error printing grammar: %s\n", __func__, err.what());
+ }
+}
+
+llama_grammar_stack llama_grammar_parser::c_rules() const {
+ llama_grammar_stack ret;
+ ret.reserve(rules.size());
+ for (const auto & rule : rules) {
+ ret.push_back(rule.data());
+ }
+ return ret;
}
// returns true iff pos points to the end of one of the definitions of a rule
static std::pair<bool, const llama_grammar_element *> llama_grammar_match_char(
const llama_grammar_element * pos,
const uint32_t chr) {
-
bool found = false;
bool is_positive_char = pos->type == LLAMA_GRETYPE_CHAR || pos->type == LLAMA_GRETYPE_CHAR_ANY;
}
}
-// takes a set of possible pushdown stacks on a grammar, which are required to
-// be positioned at a character range (see `llama_grammar_advance_stack`), and
-// produces the N possible stacks if the given char is accepted at those
-// positions
+static llama_grammar_candidates llama_grammar_reject_candidates(
+ const llama_grammar_rules & rules,
+ const llama_grammar_stacks & stacks,
+ const llama_grammar_candidates & candidates) {
+ GGML_ASSERT(!stacks.empty()); // REVIEW
+
+ if (candidates.empty()) {
+ return {};
+ }
+
+ auto rejects = llama_grammar_reject_candidates_for_stack(rules, stacks.front(), candidates);
+
+ for (size_t i = 1, size = stacks.size(); i < size; ++i) {
+ rejects = llama_grammar_reject_candidates_for_stack(rules, stacks[i], rejects);
+ }
+
+ return rejects;
+}
+
+static bool llama_grammar_detect_left_recursion(
+ const llama_grammar_rules & rules,
+ size_t rule_index,
+ std::vector<bool> * rules_visited,
+ std::vector<bool> * rules_in_progress,
+ std::vector<bool> * rules_may_be_empty) {
+ if ((*rules_in_progress)[rule_index]) {
+ return true;
+ }
+
+ (*rules_in_progress)[rule_index] = true;
+
+ const llama_grammar_rule & rule = rules[rule_index];
+
+    // First, check whether the rule might produce the empty string. This could be combined with the
+    // second step, but it's more readable as two steps.
+ bool at_rule_start = true;
+ for (size_t i = 0; i < rule.size(); i++) {
+ if (llama_grammar_is_end_of_sequence(&rule[i])) {
+ if (at_rule_start) {
+ (*rules_may_be_empty)[rule_index] = true;
+ break;
+ }
+ at_rule_start = true;
+ } else {
+ at_rule_start = false;
+ }
+ }
+
+ // Second, recurse into leftmost nonterminals (or next-leftmost as long as the previous nonterminal may
+ // be empty)
+ bool recurse_into_nonterminal = true;
+ for (size_t i = 0; i < rule.size(); i++) {
+ if (rule[i].type == LLAMA_GRETYPE_RULE_REF && recurse_into_nonterminal) {
+ if (llama_grammar_detect_left_recursion(rules, (size_t)rule[i].value, rules_visited, rules_in_progress, rules_may_be_empty)) {
+ return true;
+ }
+ if (!((*rules_may_be_empty)[(size_t)rule[i].value])) {
+ recurse_into_nonterminal = false;
+ }
+ } else if (llama_grammar_is_end_of_sequence(&rule[i])) {
+ recurse_into_nonterminal = true;
+ } else {
+ recurse_into_nonterminal = false;
+ }
+ }
+
+ (*rules_in_progress)[rule_index] = false;
+ (*rules_visited)[rule_index] = true;
+
+ return false;
+}
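+
+// example (for illustration): a grammar containing
+//   expr ::= expr "+" term | term
+// is rejected, because expr appears leftmost in its own definition and can be reached again
+// without consuming any character, while the right-recursive form
+//   expr ::= term "+" expr | term
+// is accepted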
+
+const llama_grammar_rules & llama_grammar_get_rules(const struct llama_grammar * grammar) {
+ return grammar->rules;
+}
+
+llama_grammar_stacks & llama_grammar_get_stacks(struct llama_grammar * grammar) {
+ return grammar->stacks;
+}
+
void llama_grammar_accept(
const llama_grammar_rules & rules,
const llama_grammar_stacks & stacks,
const uint32_t chr,
- llama_grammar_stacks & new_stacks) {
- new_stacks.clear();
+ llama_grammar_stacks & stacks_new) {
+ stacks_new.clear();
+ stacks_new.reserve(stacks.size());
for (const auto & stack : stacks) {
if (stack.empty()) {
if (!llama_grammar_is_end_of_sequence(pos)) {
new_stack.push_back(pos);
}
- llama_grammar_advance_stack(rules, new_stack, new_stacks);
+ llama_grammar_advance_stack(rules, new_stack, stacks_new);
}
}
}
-static llama_grammar_candidates llama_grammar_reject_candidates(
- const llama_grammar_rules & rules,
- const llama_grammar_stacks & stacks,
- const llama_grammar_candidates & candidates) {
- GGML_ASSERT(!stacks.empty()); // REVIEW
-
- if (candidates.empty()) {
- return {};
- }
-
- auto rejects = llama_grammar_reject_candidates_for_stack(rules, stacks.front(), candidates);
-
- for (size_t i = 1, size = stacks.size(); i < size; ++i) {
- rejects = llama_grammar_reject_candidates_for_stack(rules, stacks[i], rejects);
- }
- return rejects;
-}
-
llama_grammar_candidates llama_grammar_reject_candidates_for_stack(
const llama_grammar_rules & rules,
const llama_grammar_stack & stack,
return rejects;
}
-static bool llama_grammar_detect_left_recursion(
- const llama_grammar_rules & rules,
- size_t rule_index,
- std::vector<bool> * rules_visited,
- std::vector<bool> * rules_in_progress,
- std::vector<bool> * rules_may_be_empty) {
- if ((*rules_in_progress)[rule_index]) {
- return true;
- }
+////////////////////
- (*rules_in_progress)[rule_index] = true;
+struct llama_grammar * llama_grammar_init_impl(
+ const struct llama_vocab * vocab,
+ const llama_grammar_element ** rules,
+ size_t n_rules,
+ size_t start_rule_index) {
+ const llama_grammar_element * pos;
- const llama_grammar_rule & rule = rules[rule_index];
+ // copy rule definitions into vectors
+ llama_grammar_rules vec_rules(n_rules);
+ for (size_t i = 0; i < n_rules; i++) {
+ for (pos = rules[i]; pos->type != LLAMA_GRETYPE_END; pos++) {
+ vec_rules[i].push_back(*pos);
+ }
+ vec_rules[i].push_back({LLAMA_GRETYPE_END, 0});
+ }
- // First check if the rule might produce the empty string. This could be done combined with the second
- // step but it's more readable as two steps.
- bool at_rule_start = true;
- for (size_t i = 0; i < rule.size(); i++) {
- if (llama_grammar_is_end_of_sequence(&rule[i])) {
- if (at_rule_start) {
- (*rules_may_be_empty)[rule_index] = true;
- break;
- }
- at_rule_start = true;
- } else {
- at_rule_start = false;
+ // Check for left recursion
+ std::vector<bool> rules_visited(n_rules);
+ std::vector<bool> rules_in_progress(n_rules);
+ std::vector<bool> rules_may_be_empty(n_rules);
+ for (size_t i = 0; i < n_rules; i++) {
+ if (rules_visited[i]) {
+ continue;
+ }
+ if (llama_grammar_detect_left_recursion(vec_rules, i, &rules_visited, &rules_in_progress, &rules_may_be_empty)) {
+ LLAMA_LOG_ERROR("unsupported grammar, left recursion detected for nonterminal at index %zu", i);
+ return nullptr;
}
}
- // Second, recurse into leftmost nonterminals (or next-leftmost as long as the previous nonterminal may
- // be empty)
- bool recurse_into_nonterminal = true;
- for (size_t i = 0; i < rule.size(); i++) {
- if (rule[i].type == LLAMA_GRETYPE_RULE_REF && recurse_into_nonterminal) {
- if (llama_grammar_detect_left_recursion(rules, (size_t)rule[i].value, rules_visited, rules_in_progress, rules_may_be_empty)) {
- return true;
- }
- if (!((*rules_may_be_empty)[(size_t)rule[i].value])) {
- recurse_into_nonterminal = false;
- }
- } else if (llama_grammar_is_end_of_sequence(&rule[i])) {
- recurse_into_nonterminal = true;
+ // loop over alternates of start rule to build initial stacks
+ llama_grammar_stacks stacks;
+ pos = vec_rules[start_rule_index].data();
+ do {
+ llama_grammar_stack stack;
+ if (!llama_grammar_is_end_of_sequence(pos)) {
+ // if alternate is nonempty, add to stack
+ stack.push_back(pos);
+ }
+ llama_grammar_advance_stack(vec_rules, stack, stacks);
+ while (!llama_grammar_is_end_of_sequence(pos)) {
+ // scan to end of alternate def
+ pos++;
+ }
+ if (pos->type == LLAMA_GRETYPE_ALT) {
+ // there's another alternate def of this rule to process
+ pos++;
} else {
- recurse_into_nonterminal = false;
+ break;
}
- }
+ } while (true);
- (*rules_in_progress)[rule_index] = false;
- (*rules_visited)[rule_index] = true;
- return false;
+ // Important: vec_rules has to be moved here, not copied, because stacks contains
+ // pointers to elements of vec_rules. If vec_rules were copied into llama_grammar
+ // then the pointers would be invalidated when the local vec_rules goes out of scope.
+ return new llama_grammar { vocab, std::move(vec_rules), std::move(stacks), {}, };
}
-//
-// grammar - external
-//
+struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab * vocab, const char * grammar_str, const char * grammar_root) {
+ llama_grammar_parser parser;
+
+ // if there is a grammar, parse it
+ if (!parser.parse(grammar_str)) {
+ return nullptr;
+ }
+
+ // will be empty (default) if there are parse errors
+ if (parser.rules.empty()) {
+ fprintf(stderr, "%s: failed to parse grammar\n", __func__);
+ return nullptr;
+ }
+
+ // Ensure that there is a "root" node.
+ if (parser.symbol_ids.find("root") == parser.symbol_ids.end()) {
+ fprintf(stderr, "%s: grammar does not contain a 'root' symbol\n", __func__);
+ return nullptr;
+ }
+
+ std::vector<const llama_grammar_element *> grammar_rules(parser.c_rules());
+
+ const size_t n_rules = grammar_rules.size();
+ const size_t start_rule_index = parser.symbol_ids.at(grammar_root);
-struct llama_grammar * llama_grammar_init_impl(
- const llama_grammar_element ** rules,
- size_t n_rules,
- size_t start_rule_index) {
const llama_grammar_element * pos;
// copy rule definitions into vectors
llama_grammar_rules vec_rules(n_rules);
for (size_t i = 0; i < n_rules; i++) {
- for (pos = rules[i]; pos->type != LLAMA_GRETYPE_END; pos++) {
+ for (pos = grammar_rules[i]; pos->type != LLAMA_GRETYPE_END; pos++) {
vec_rules[i].push_back(*pos);
}
vec_rules[i].push_back({LLAMA_GRETYPE_END, 0});
// Important: vec_rules has to be moved here, not copied, because stacks contains
// pointers to elements of vec_rules. If vec_rules were copied into llama_grammar
// then the pointers would be invalidated when the local vec_rules goes out of scope.
- return new llama_grammar{ std::move(vec_rules), std::move(stacks), {} };
+ return new llama_grammar { vocab, std::move(vec_rules), std::move(stacks), {}, };
}
void llama_grammar_free_impl(struct llama_grammar * grammar) {
+ if (grammar == nullptr) {
+ return;
+ }
+
delete grammar;
}
-struct llama_grammar * llama_grammar_copy_impl(const struct llama_grammar * grammar) {
- llama_grammar * result = new llama_grammar{ grammar->rules, grammar->stacks, grammar->partial_utf8 };
+struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & grammar) {
+ llama_grammar * result = new llama_grammar { grammar.vocab, grammar.rules, grammar.stacks, grammar.partial_utf8, };
// redirect elements in stacks to point to new rules
for (size_t is = 0; is < result->stacks.size(); is++) {
for (size_t ie = 0; ie < result->stacks[is].size(); ie++) {
- for (size_t ir0 = 0; ir0 < grammar->rules.size(); ir0++) {
- for (size_t ir1 = 0; ir1 < grammar->rules[ir0].size(); ir1++) {
- if (grammar->stacks[is][ie] == &grammar->rules[ir0][ir1]) {
+ for (size_t ir0 = 0; ir0 < grammar.rules.size(); ir0++) {
+ for (size_t ir1 = 0; ir1 < grammar.rules[ir0].size(); ir1++) {
+ if (grammar.stacks[is][ie] == &grammar.rules[ir0][ir1]) {
result->stacks[is][ie] = &result->rules[ir0][ir1];
}
}
return result;
}
-void llama_grammar_sample_impl(const struct llama_grammar * grammar, const struct llama_vocab * vocab, const struct llama_sampling * smpl, llama_token_data_array * candidates) {
- GGML_ASSERT(grammar);
- GGML_ASSERT(vocab);
-
- int64_t t_start_sample_us = ggml_time_us();
+void llama_grammar_apply_impl(const struct llama_grammar & grammar, llama_token_data_array * cur_p) {
+ GGML_ASSERT(grammar.vocab != nullptr);
bool allow_eog = false;
- for (const auto & stack : grammar->stacks) {
+ for (const auto & stack : grammar.stacks) {
if (stack.empty()) {
allow_eog = true;
break;
}
std::vector<std::pair<std::vector<uint32_t>, llama_partial_utf8>> candidates_decoded;
- candidates_decoded.reserve(candidates->size);
+ candidates_decoded.reserve(cur_p->size);
llama_grammar_candidates candidates_grammar;
- candidates_grammar.reserve(candidates->size);
+ candidates_grammar.reserve(cur_p->size);
- for (size_t i = 0; i < candidates->size; ++i) {
- const llama_token id = candidates->data[i].id;
- const std::string & piece = vocab->cache_token_to_piece.at(id);
+ for (size_t i = 0; i < cur_p->size; ++i) {
+ const llama_token id = cur_p->data[i].id;
+ const std::string & piece = grammar.vocab->cache_token_to_piece.at(id);
- if (llama_token_is_eog_impl(*vocab, id)) {
+ if (llama_token_is_eog_impl(*grammar.vocab, id)) {
if (!allow_eog) {
- candidates->data[i].logit = -INFINITY;
+ cur_p->data[i].logit = -INFINITY;
}
} else if (piece.empty() || piece[0] == 0) {
- candidates->data[i].logit = -INFINITY;
+ cur_p->data[i].logit = -INFINITY;
} else {
- candidates_decoded.push_back(decode_utf8(piece, grammar->partial_utf8));
+ candidates_decoded.push_back(decode_utf8(piece, grammar.partial_utf8));
candidates_grammar.push_back({ i, candidates_decoded.back().first.data(), candidates_decoded.back().second });
}
}
- const auto rejects = llama_grammar_reject_candidates(grammar->rules, grammar->stacks, candidates_grammar);
+ const auto rejects = llama_grammar_reject_candidates(grammar.rules, grammar.stacks, candidates_grammar);
for (const auto & reject : rejects) {
- candidates->data[reject.index].logit = -INFINITY;
+ cur_p->data[reject.index].logit = -INFINITY;
}
-
- smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
}
-void llama_grammar_accept_token_impl(struct llama_grammar * grammar, const struct llama_vocab * vocab, const struct llama_sampling * smpl, llama_token token) {
- const int64_t t_start_sample_us = ggml_time_us();
+void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token) {
+ GGML_ASSERT(grammar.vocab != nullptr);
- if (llama_token_is_eog_impl(*vocab, token)) {
- for (const auto & stack : grammar->stacks) {
+ if (llama_token_is_eog_impl(*grammar.vocab, token)) {
+ for (const auto & stack : grammar.stacks) {
if (stack.empty()) {
return;
}
GGML_ABORT("fatal error");
}
- const std::string & piece = vocab->cache_token_to_piece.at(token);
+ const std::string & piece = grammar.vocab->cache_token_to_piece.at(token);
// Note terminating 0 in decoded string
- const auto decoded = decode_utf8(piece, grammar->partial_utf8);
+ const auto decoded = decode_utf8(piece, grammar.partial_utf8);
const auto & code_points = decoded.first;
- llama_grammar_stacks tmp_new_stacks;
+ llama_grammar_stacks stacks_new;
+
for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
- llama_grammar_accept(grammar->rules, grammar->stacks, *it, tmp_new_stacks);
- grammar->stacks = tmp_new_stacks;
+ llama_grammar_accept(grammar.rules, grammar.stacks, *it, stacks_new);
+ grammar.stacks = std::move(stacks_new);
}
- grammar->partial_utf8 = decoded.second;
- GGML_ASSERT(!grammar->stacks.empty());
-
- smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
+ grammar.partial_utf8 = decoded.second;
+ GGML_ASSERT(!grammar.stacks.empty());
}
#include "llama-impl.h"
+#include <map>
+
struct llama_vocab;
-struct llama_sampling;
+
+// grammar element type
+enum llama_gretype {
+ // end of rule definition
+ LLAMA_GRETYPE_END = 0,
+
+ // start of alternate definition for rule
+ LLAMA_GRETYPE_ALT = 1,
+
+ // non-terminal element: reference to rule
+ LLAMA_GRETYPE_RULE_REF = 2,
+
+ // terminal element: character (code point)
+ LLAMA_GRETYPE_CHAR = 3,
+
+    // inverse char(s) ([^a], [^a-b], [^abc])
+ LLAMA_GRETYPE_CHAR_NOT = 4,
+
+ // modifies a preceding LLAMA_GRETYPE_CHAR or LLAMA_GRETYPE_CHAR_ALT to
+ // be an inclusive range ([a-z])
+ LLAMA_GRETYPE_CHAR_RNG_UPPER = 5,
+
+ // modifies a preceding LLAMA_GRETYPE_CHAR or
+ // LLAMA_GRETYPE_CHAR_RNG_UPPER to add an alternate char to match ([ab], [a-zA])
+ LLAMA_GRETYPE_CHAR_ALT = 6,
+
+ // any character (.)
+ LLAMA_GRETYPE_CHAR_ANY = 7,
+};
+
+typedef struct llama_grammar_element {
+ enum llama_gretype type;
+ uint32_t value; // Unicode code point or rule ID
+} llama_grammar_element;
+
+struct llama_partial_utf8 {
+ uint32_t value; // bit value so far (unshifted)
+ int n_remain; // num bytes remaining; -1 indicates invalid sequence
+};
+
+struct llama_grammar_candidate {
+ size_t index;
+ const uint32_t * code_points;
+ llama_partial_utf8 partial_utf8;
+};
+
+using llama_grammar_rule = std::vector< llama_grammar_element>;
+using llama_grammar_stack = std::vector<const llama_grammar_element *>;
+
+using llama_grammar_rules = std::vector<llama_grammar_rule>;
+using llama_grammar_stacks = std::vector<llama_grammar_stack>;
+using llama_grammar_candidates = std::vector<llama_grammar_candidate>;
+
+const llama_grammar_rules & llama_grammar_get_rules (const struct llama_grammar * grammar);
+ llama_grammar_stacks & llama_grammar_get_stacks( struct llama_grammar * grammar);
+
+// takes a set of possible pushdown stacks on a grammar, which are required to
+// be positioned at a character range (see `llama_grammar_advance_stack`), and
+// produces the N possible stacks if the given char is accepted at those
+// positions
+void llama_grammar_accept(
+ const llama_grammar_rules & rules,
+ const llama_grammar_stacks & stacks,
+ uint32_t chr,
+ llama_grammar_stacks & stacks_new);
+
+std::vector<llama_grammar_candidate> llama_grammar_reject_candidates_for_stack(
+ const llama_grammar_rules & rules,
+ const llama_grammar_stack & stack,
+ const llama_grammar_candidates & candidates);
+
+struct llama_grammar_parser {
+ std::map<std::string, uint32_t> symbol_ids;
+
+ llama_grammar_rules rules;
+
+ llama_grammar_stack c_rules() const;
+
+ uint32_t get_symbol_id(const char * src, size_t len);
+ uint32_t generate_symbol_id(const std::string & base_name);
+
+ void add_rule(uint32_t rule_id, const llama_grammar_rule & rule);
+
+ const char * parse_alternates(
+ const char * src,
+ const std::string & rule_name,
+ uint32_t rule_id,
+ bool is_nested);
+
+ const char * parse_sequence(
+ const char * src,
+ const std::string & rule_name,
+ llama_grammar_rule & rule,
+ bool is_nested);
+
+ const char * parse_rule(const char * src);
+
+ bool parse(const char * src);
+ void print(FILE * file);
+};
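+
+// usage sketch (for illustration only):
+//
+//   llama_grammar_parser parser;
+//   if (parser.parse("root ::= \"yes\" | \"no\"")) {
+//       parser.print(stderr);          // dump the parsed rules
+//       auto rules = parser.c_rules(); // C-style rule pointers for llama_grammar_init_impl
+//   }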
struct llama_grammar {
- const llama_grammar_rules rules;
+ // note: allow null vocab for testing (not great)
+ const llama_vocab * vocab;
+
+ const llama_grammar_rules rules; // TODO: shared ptr
llama_grammar_stacks stacks;
// buffer for partially generated UTF-8 sequence from accepted tokens
// internal API
//
+// note: needed for tests (not great)
struct llama_grammar * llama_grammar_init_impl(
- const llama_grammar_element ** rules,
- size_t n_rules,
- size_t start_rule_index);
+ const struct llama_vocab * vocab,
+ const llama_grammar_element ** rules,
+ size_t n_rules,
+ size_t start_rule_index);
+
+struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab * vocab, const char * grammar_str, const char * grammar_root);
void llama_grammar_free_impl(struct llama_grammar * grammar);
-struct llama_grammar * llama_grammar_copy_impl(const struct llama_grammar * grammar);
+struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & grammar);
-void llama_grammar_sample_impl(
- const struct llama_grammar * grammar,
- const struct llama_vocab * vocab,
- const struct llama_sampling * smpl,
- llama_token_data_array * candidates);
+// TODO: move the API below as member functions of llama_grammar
+void llama_grammar_apply_impl(
+ const struct llama_grammar & grammar,
+ llama_token_data_array * cur_p);
-void llama_grammar_accept_token_impl(
- struct llama_grammar * grammar,
- const struct llama_vocab * vocab,
- const struct llama_sampling * smpl,
+void llama_grammar_accept_impl(
+ struct llama_grammar & grammar,
llama_token token);
#pragma once
-#define LLAMA_API_INTERNAL
#include "llama.h"
+#include <string>
+#include <vector>
+#include <stdexcept>
+
#ifdef __GNUC__
#ifdef __MINGW32__
#define LLAMA_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__)))
// helpers
//
+struct time_meas {
+ time_meas(int64_t & t_acc, bool disable = false) : t_start_us(disable ? -1 : ggml_time_us()), t_acc(t_acc) {}
+
+ ~time_meas() {
+ if (t_start_us >= 0) {
+ t_acc += ggml_time_us() - t_start_us;
+ }
+ }
+
+ const int64_t t_start_us;
+
+ int64_t & t_acc;
+};
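+
+// usage sketch (for illustration only):
+//
+//   int64_t t_acc_us = 0;
+//   {
+//       time_meas tm(t_acc_us); // starts the clock; adds the elapsed time to t_acc_us on scope exit
+//       // ... timed work ...
+//   }
+//   // pass disable = true to turn the measurement into a no-op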
+
static void replace_all(std::string & s, const std::string & search, const std::string & replace) {
if (search.empty()) {
return;
builder.append(s, last_pos, std::string::npos);
s = std::move(builder);
}
+
+const std::vector<std::pair<std::string, struct ggml_tensor *>> & llama_internal_get_tensor_map(
+ struct llama_context * ctx
+);
+
+// the ring buffer works similarly to std::deque, but with a fixed capacity
+template<typename T>
+struct ring_buffer {
+ ring_buffer(size_t cap) : capacity(cap), data(cap) {}
+
+ T & front() {
+ if (sz == 0) {
+ throw std::runtime_error("ring buffer is empty");
+ }
+ return data[first];
+ }
+
+ const T & front() const {
+ if (sz == 0) {
+ throw std::runtime_error("ring buffer is empty");
+ }
+ return data[first];
+ }
+
+ T & back() {
+ if (sz == 0) {
+ throw std::runtime_error("ring buffer is empty");
+ }
+        return data[(first + sz - 1) % capacity]; // pos already points to the next write slot
+ }
+
+ const T & back() const {
+ if (sz == 0) {
+ throw std::runtime_error("ring buffer is empty");
+ }
+        return data[(first + sz - 1) % capacity]; // pos already points to the next write slot
+ }
+
+ void push_back(const T & value) {
+ if (sz == capacity) {
+ // advance the start when buffer is full
+ first = (first + 1) % capacity;
+ } else {
+ sz++;
+ }
+ data[pos] = value;
+ pos = (pos + 1) % capacity;
+ }
+
+ T pop_front() {
+ if (sz == 0) {
+ throw std::runtime_error("ring buffer is empty");
+ }
+ T value = data[first];
+ first = (first + 1) % capacity;
+ sz--;
+ return value;
+ }
+
+ //T & operator[](size_t i) {
+ // if (i >= sz) {
+ // throw std::runtime_error("ring buffer: index out of bounds");
+ // }
+ // return data[(first + i) % capacity];
+ //}
+
+ //const T & at(size_t i) const {
+ // if (i >= sz) {
+ // throw std::runtime_error("ring buffer: index out of bounds");
+ // }
+ // return data[(first + i) % capacity];
+ //}
+
+ const T & rat(size_t i) const {
+ if (i >= sz) {
+ throw std::runtime_error("ring buffer: index out of bounds");
+ }
+ return data[(first + sz - i - 1) % capacity];
+ }
+
+ std::vector<T> to_vector() const {
+ std::vector<T> result;
+ result.reserve(sz);
+ for (size_t i = 0; i < sz; i++) {
+ result.push_back(data[(first + i) % capacity]);
+ }
+ return result;
+ }
+
+ void clear() {
+        // only reset the bookkeeping; the stored elements are left in place
+ sz = 0;
+ first = 0;
+ pos = 0;
+ }
+
+ bool empty() const {
+ return sz == 0;
+ }
+
+ size_t size() const {
+ return sz;
+ }
+
+ size_t capacity = 0;
+ size_t sz = 0;
+ size_t first = 0;
+ size_t pos = 0;
+ std::vector<T> data;
+};
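+
+// usage sketch (for illustration only): keep only the 3 most recent values
+//
+//   ring_buffer<int> rb(3);
+//   rb.push_back(1); rb.push_back(2); rb.push_back(3); rb.push_back(4); // 1 is overwritten
+//   // rb.front() == 2, rb.rat(0) == 4 (most recent), rb.to_vector() == {2, 3, 4}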
#include "llama-sampling.h"
+#include "llama-vocab.h"
+#include "llama-grammar.h"
+
+#include <cassert>
#include <algorithm>
#include <cstring>
#include <ctime>
#include <cfloat>
#include <numeric>
+#include <random>
#include <unordered_map>
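+
+// note: llama_sample_dist below assumes cur_p->data[i].p has already been populated (e.g. by
+// llama_sampler_softmax_impl); std::discrete_distribution normalizes the weights itself, so
+// they do not need to sum to exactly 1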
+static int llama_sample_dist(llama_token_data_array * cur_p, std::mt19937 & rng, std::vector<float> & probs) {
+ probs.resize(cur_p->size);
+ for (size_t i = 0; i < cur_p->size; ++i) {
+ probs[i] = cur_p->data[i].p;
+ }
+
+ std::discrete_distribution<size_t> dist(probs.begin(), probs.end());
+
+ return dist(rng);
+}
+
static void llama_log_softmax(float * array, size_t size) {
float max_l = *std::max_element(array, array + size);
float sum = 0.f;
}
}
-void llama_set_rng_seed_impl(struct llama_sampling * smpl, uint32_t seed) {
- if (seed == LLAMA_DEFAULT_SEED) {
- seed = time(NULL);
- }
-
- smpl->rng.seed(seed);
-}
-
-void llama_sample_softmax_impl(struct llama_sampling * smpl, llama_token_data_array * candidates) {
- GGML_ASSERT(candidates->size > 0);
-
- const int64_t t_start_sample_us = ggml_time_us();
+static void llama_sampler_softmax_impl(llama_token_data_array * cur_p) {
+ GGML_ASSERT(cur_p->size > 0);
// Sort the logits in descending order
- if (!candidates->sorted) {
- std::sort(candidates->data, candidates->data + candidates->size, [](const llama_token_data & a, const llama_token_data & b) {
+ if (!cur_p->sorted) {
+ std::sort(cur_p->data, cur_p->data + cur_p->size, [](const llama_token_data & a, const llama_token_data & b) {
return a.logit > b.logit;
});
- candidates->sorted = true;
+ cur_p->sorted = true;
}
- float max_l = candidates->data[0].logit;
+ float max_l = cur_p->data[0].logit;
float cum_sum = 0.0f;
- for (size_t i = 0; i < candidates->size; ++i) {
- float p = expf(candidates->data[i].logit - max_l);
- candidates->data[i].p = p;
+
+ for (size_t i = 0; i < cur_p->size; ++i) {
+ float p = expf(cur_p->data[i].logit - max_l);
+ cur_p->data[i].p = p;
cum_sum += p;
}
- for (size_t i = 0; i < candidates->size; ++i) {
- candidates->data[i].p /= cum_sum;
- }
- if (smpl) {
- smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
+ for (size_t i = 0; i < cur_p->size; ++i) {
+ cur_p->data[i].p /= cum_sum;
}
}
-void llama_sample_top_k_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, int32_t k, size_t min_keep) {
+static void llama_sampler_top_k_impl(llama_token_data_array * cur_p, int32_t k) {
// TODO: move bucket sort to separate function so that top_p/tail_free/typical/softmax first is equally fast
- // if (k >= (int32_t)candidates->size) {
+ // if (k >= (int32_t)cur_p->size) {
// return;
// }
- const int64_t t_start_sample_us = ggml_time_us();
-
if (k <= 0) {
- k = candidates->size;
+ k = cur_p->size;
}
- k = std::max(k, (int) min_keep);
- k = std::min(k, (int) candidates->size);
+ k = std::min(k, (int) cur_p->size);
// Sort scores in descending order
- if (!candidates->sorted) {
+ if (!cur_p->sorted) {
auto comp = [](const llama_token_data & a, const llama_token_data & b) {
return a.logit > b.logit;
};
if (k <= 128) {
- std::partial_sort(candidates->data, candidates->data + k, candidates->data + candidates->size, comp);
+ std::partial_sort(cur_p->data, cur_p->data + k, cur_p->data + cur_p->size, comp);
} else {
constexpr int nbuckets = 128;
constexpr float bucket_low = -10.0f;
constexpr float bucket_scale = nbuckets/(bucket_high - bucket_low);
constexpr float bucket_inter = -bucket_low * bucket_scale;
- std::vector<int> bucket_idx(candidates->size);
+ std::vector<int> bucket_idx(cur_p->size);
std::vector<int> histo(nbuckets, 0);
- for (int i = 0; i < (int)candidates->size; ++i) {
- const float val = candidates->data[i].logit;
+ for (int i = 0; i < (int)cur_p->size; ++i) {
+ const float val = cur_p->data[i].logit;
int ib = int(bucket_scale * val + bucket_inter); //nbuckets * (val - bucket_low) / (bucket_high - bucket_low);
ib = std::max(0, std::min(nbuckets-1, ib));
bucket_idx[i] = ib;
int ib = nbuckets - 1;
for ( ; ib >= 0; --ib) {
nhave += histo[ib];
- if (nhave >= k) break;
+ if (nhave >= k) {
+ break;
+ }
}
std::vector<llama_token_data> tmp_tokens(nhave);
- auto ptr = tmp_tokens.data();
+ auto * ptr = tmp_tokens.data();
std::vector<llama_token_data*> bucket_ptrs;
bucket_ptrs.reserve(nbuckets - ib);
for (int j = nbuckets - 1; j >= ib; --j) {
bucket_ptrs.push_back(ptr);
ptr += histo[j];
}
- for (int i = 0; i < (int)candidates->size; ++i) {
+ for (int i = 0; i < (int)cur_p->size; ++i) {
int j = bucket_idx[i];
if (j >= ib) {
- *bucket_ptrs[nbuckets-1-j]++ = candidates->data[i];
+ *bucket_ptrs[nbuckets-1-j]++ = cur_p->data[i];
}
}
}
std::partial_sort(ptr, ptr + k - ndone, ptr + histo[ib], comp);
- std::memcpy(candidates->data, tmp_tokens.data(), k*sizeof(llama_token_data));
+ std::memcpy(cur_p->data, tmp_tokens.data(), k*sizeof(llama_token_data));
}
- candidates->sorted = true;
- }
- candidates->size = k;
-
- if (smpl) {
- smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
+ cur_p->sorted = true;
}
+ cur_p->size = k;
}
-void llama_sample_top_p_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep) {
+static void llama_sampler_top_p_impl(llama_token_data_array * cur_p, float p, size_t min_keep) {
if (p >= 1.0f) {
return;
}
- llama_sample_softmax_impl(smpl, candidates);
-
- const int64_t t_start_sample_us = ggml_time_us();
+ llama_sampler_softmax_impl(cur_p);
// Compute the cumulative probabilities
float cum_sum = 0.0f;
- size_t last_idx = candidates->size;
+ size_t last_idx = cur_p->size;
- for (size_t i = 0; i < candidates->size; ++i) {
- cum_sum += candidates->data[i].p;
+ for (size_t i = 0; i < cur_p->size; ++i) {
+ cum_sum += cur_p->data[i].p;
// Check if the running sum is at least p or if we have kept at least min_keep tokens
// we set the last index to i+1 to indicate that the current iterate should be included in the set
}
// Resize the output vector to keep only the top-p tokens
- candidates->size = last_idx;
-
- if (smpl) {
- smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
- }
+ cur_p->size = last_idx;
}
-void llama_sample_min_p_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep) {
- if (p <= 0.0f || !candidates->size) {
+static void llama_sampler_min_p_impl(llama_token_data_array * cur_p, float p, size_t min_keep) {
+ if (p <= 0.0f || !cur_p->size) {
return;
}
- const int64_t t_start_sample_us = ggml_time_us();
-
bool min_p_applied = false;
- // if the candidates aren't sorted, try the unsorted implementation first
- if (!candidates->sorted) {
+    // if the candidates in cur_p aren't sorted, try the unsorted implementation first
+ if (!cur_p->sorted) {
std::vector<llama_token_data> filtered_tokens;
float max_logit = -FLT_MAX;
- for (size_t i = 0; i < candidates->size; ++i) {
- max_logit = std::max(max_logit, candidates->data[i].logit);
+ for (size_t i = 0; i < cur_p->size; ++i) {
+ max_logit = std::max(max_logit, cur_p->data[i].logit);
}
const float min_logit = max_logit + logf(p); // min logit for p_i >= p * p_max
- for (size_t i = 0; i < candidates->size; ++i) {
- if (candidates->data[i].logit >= min_logit) {
- filtered_tokens.push_back(candidates->data[i]);
+ for (size_t i = 0; i < cur_p->size; ++i) {
+ if (cur_p->data[i].logit >= min_logit) {
+ filtered_tokens.push_back(cur_p->data[i]);
}
}
// if we have enough values the operation was a success
if (filtered_tokens.size() >= min_keep) {
- memcpy(candidates->data, filtered_tokens.data(), filtered_tokens.size()*sizeof(llama_token_data));
- candidates->size = filtered_tokens.size();
+ memcpy(cur_p->data, filtered_tokens.data(), filtered_tokens.size()*sizeof(llama_token_data));
+ cur_p->size = filtered_tokens.size();
min_p_applied = true;
}
}
- // if the candidates are sorted or the unsorted implementation failed, use this implementation
+    // if the candidates are sorted or the unsorted implementation failed, use this implementation
if (!min_p_applied) {
// Sort the logits in descending order
- if (!candidates->sorted) {
- std::sort(candidates->data, candidates->data + candidates->size, [](const llama_token_data & a, const llama_token_data & b) {
+ if (!cur_p->sorted) {
+ std::sort(cur_p->data, cur_p->data + cur_p->size, [](const llama_token_data & a, const llama_token_data & b) {
return a.logit > b.logit;
});
- candidates->sorted = true;
+ cur_p->sorted = true;
}
- const float min_logit = candidates->data[0].logit + logf(p); // min logit for p_i >= p * p_max
+ const float min_logit = cur_p->data[0].logit + logf(p); // min logit for p_i >= p * p_max
size_t i = 1; // first token always matches
- for (; i < candidates->size; ++i) {
- if (candidates->data[i].logit < min_logit && i >= min_keep) {
+ for (; i < cur_p->size; ++i) {
+ if (cur_p->data[i].logit < min_logit && i >= min_keep) {
break; // prob too small
}
}
// Resize the output vector to keep only the matching tokens
- candidates->size = i;
- }
-
- if (smpl) {
- smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
+ cur_p->size = i;
}
}
-void llama_sample_tail_free_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float z, size_t min_keep) {
- if (z >= 1.0f || candidates->size <= 2) {
+static void llama_sampler_tail_free_impl(llama_token_data_array * cur_p, float z, size_t min_keep) {
+ if (z >= 1.0f || cur_p->size <= 2) {
return;
}
- llama_sample_softmax_impl((struct llama_sampling *) nullptr, candidates);
- const int64_t t_start_sample_us = ggml_time_us();
+ llama_sampler_softmax_impl(cur_p);
// Compute the first and second derivatives
- std::vector<float> first_derivatives(candidates->size - 1);
- std::vector<float> second_derivatives(candidates->size - 2);
+ std::vector<float> first_derivatives(cur_p->size - 1);
+ std::vector<float> second_derivatives(cur_p->size - 2);
for (size_t i = 0; i < first_derivatives.size(); ++i) {
- first_derivatives[i] = candidates->data[i].p - candidates->data[i + 1].p;
+ first_derivatives[i] = cur_p->data[i].p - cur_p->data[i + 1].p;
}
for (size_t i = 0; i < second_derivatives.size(); ++i) {
second_derivatives[i] = first_derivatives[i] - first_derivatives[i + 1];
}
float cum_sum = 0.0f;
- size_t last_idx = candidates->size;
+ size_t last_idx = cur_p->size;
for (size_t i = 0; i < second_derivatives.size(); ++i) {
cum_sum += second_derivatives[i];
}
// Resize the output vector to keep only the tokens above the tail location
- candidates->size = last_idx;
-
- if (smpl) {
- smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
- }
+ cur_p->size = last_idx;
}
-void llama_sample_typical_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep) {
+static void llama_sampler_typical_impl(llama_token_data_array * cur_p, float p, size_t min_keep) {
// Reference implementation:
// https://github.com/huggingface/transformers/compare/main...cimeister:typical-sampling:typical-pr
if (p >= 1.0f) {
}
// Compute the softmax of logits and calculate entropy
- llama_sample_softmax_impl((struct llama_sampling *) nullptr, candidates);
-
- const int64_t t_start_sample_us = ggml_time_us();
+ llama_sampler_softmax_impl(cur_p);
float entropy = 0.0f;
- for (size_t i = 0; i < candidates->size; ++i) {
- entropy += -candidates->data[i].p * logf(candidates->data[i].p);
+ for (size_t i = 0; i < cur_p->size; ++i) {
+ entropy += -cur_p->data[i].p * logf(cur_p->data[i].p);
}
// Compute the absolute difference between negative log probability and entropy for each candidate
std::vector<float> shifted_scores;
- for (size_t i = 0; i < candidates->size; ++i) {
- float shifted_score = fabsf(-logf(candidates->data[i].p) - entropy);
+ for (size_t i = 0; i < cur_p->size; ++i) {
+ float shifted_score = fabsf(-logf(cur_p->data[i].p) - entropy);
shifted_scores.push_back(shifted_score);
}
// Sort tokens based on the shifted_scores and their corresponding indices
- std::vector<size_t> indices(candidates->size);
+ std::vector<size_t> indices(cur_p->size);
std::iota(indices.begin(), indices.end(), 0);
std::sort(indices.begin(), indices.end(), [&](size_t a, size_t b) {
for (size_t i = 0; i < indices.size(); ++i) {
size_t idx = indices[i];
- cum_sum += candidates->data[idx].p;
+ cum_sum += cur_p->data[idx].p;
// Check if the running sum is greater than typical or if we have kept at least min_keep tokens
if (cum_sum > p && i >= min_keep - 1) {
}
// Resize the output vector to keep only the locally typical tokens
- std::vector<llama_token_data> new_candidates;
+ std::vector<llama_token_data> cur_p_new;
for (size_t i = 0; i < last_idx; ++i) {
size_t idx = indices[i];
- new_candidates.push_back(candidates->data[idx]);
+ cur_p_new.push_back(cur_p->data[idx]);
}
- // Replace the data in candidates with the new_candidates data
- std::copy(new_candidates.begin(), new_candidates.end(), candidates->data);
- candidates->size = new_candidates.size();
- candidates->sorted = false;
-
- if (smpl) {
- smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
- }
+ // Replace the data in cur_p with the cur_p_new data
+ std::copy(cur_p_new.begin(), cur_p_new.end(), cur_p->data);
+ cur_p->size = cur_p_new.size();
+ cur_p->sorted = false;
}
-void llama_sample_entropy_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float min_temp, float max_temp, float exponent_val) {
- const int64_t t_start_sample_us = ggml_time_us();
-
+static void llama_sampler_entropy_impl(llama_token_data_array * cur_p, float min_temp, float max_temp, float exponent_val) {
// no need to do anything if there is only one (or zero) candidates
- if(candidates->size <= 1) {
+ if (cur_p->size <= 1) {
return;
}
// Calculate maximum possible entropy
- float max_entropy = -logf(1.0f / candidates->size);
+ float max_entropy = -logf(1.0f / cur_p->size);
- llama_sample_softmax_impl((struct llama_sampling *) nullptr, candidates);
+ llama_sampler_softmax_impl(cur_p);
// Calculate entropy of the softmax probabilities
float entropy = 0.0f;
- for (size_t i = 0; i < candidates->size; ++i) {
- float prob = candidates->data[i].p;
+ for (size_t i = 0; i < cur_p->size; ++i) {
+ float prob = cur_p->data[i].p;
if (prob > 0.0f) { // Ensure no log(0)
entropy -= prob * logf(prob);
}
}
- // Normalize the entropy (max_entropy cannot be 0 here because we checked candidates->size != 1 above)
+ // Normalize the entropy (max_entropy cannot be 0 here because we checked cur_p->size != 1 above)
float normalized_entropy = entropy / max_entropy;
// Map the normalized entropy to the desired temperature range using the power function
#endif
// Apply the dynamically calculated temperature scaling
- for (size_t i = 0; i < candidates->size; ++i) {
- candidates->data[i].logit /= dyn_temp;
+ for (size_t i = 0; i < cur_p->size; ++i) {
+ cur_p->data[i].logit /= dyn_temp;
}
// Re-compute softmax probabilities after scaling logits with dynamic temperature
- double max_l_double = candidates->data[0].logit;
+ const double max_l_double = cur_p->data[0].logit;
+
double cum_sum_double = 0.0;
- for (size_t i = 0; i < candidates->size; ++i) {
- double p = exp(candidates->data[i].logit - max_l_double);
- candidates->data[i].p = p; // Store the scaled probability
+ for (size_t i = 0; i < cur_p->size; ++i) {
+ double p = exp(cur_p->data[i].logit - max_l_double);
+ cur_p->data[i].p = p; // Store the scaled probability
cum_sum_double += p;
}
- for (size_t i = 0; i < candidates->size; ++i) {
- candidates->data[i].p /= cum_sum_double; // Re-normalize the probabilities
+
+ for (size_t i = 0; i < cur_p->size; ++i) {
+ cur_p->data[i].p /= cum_sum_double; // Re-normalize the probabilities
}
#ifdef DEBUG
// Print the updated top 25 probabilities after temperature scaling
LLAMA_LOG_INFO("\nUpdated Top 25 Probabilities After Dynamic Temperature Scaling (in percentages):\n");
- for (size_t i = 0; i < 25 && i < candidates->size; ++i) {
- LLAMA_LOG_INFO("Token %zu: %f%%\n", i + 1, candidates->data[i].p * 100.0f);
+ for (size_t i = 0; i < 25 && i < cur_p->size; ++i) {
+ LLAMA_LOG_INFO("Token %zu: %f%%\n", i + 1, cur_p->data[i].p * 100.0f);
}
#endif
-
- if (smpl) {
- smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
- }
}
-void llama_sample_temp_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float temp) {
- const int64_t t_start_sample_us = ggml_time_us();
-
- for (size_t i = 0; i < candidates->size; ++i) {
- candidates->data[i].logit /= temp;
- }
-
- if (smpl) {
- smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
+static void llama_sampler_temp_impl(llama_token_data_array * cur_p, float temp) {
+ for (size_t i = 0; i < cur_p->size; ++i) {
+ cur_p->data[i].logit /= temp;
}
}
-void llama_sample_repetition_penalties_impl(
- struct llama_sampling * smpl,
- llama_token_data_array * candidates,
- const llama_token * last_tokens,
- size_t penalty_last_n,
- float penalty_repeat,
- float penalty_freq,
- float penalty_present) {
- if (penalty_last_n == 0 || (penalty_repeat == 1.0f && penalty_freq == 0.0f && penalty_present == 0.0f)) {
- return;
- }
-
- const int64_t t_start_sample_us = ggml_time_us();
-
- // Create a frequency map to count occurrences of each token in last_tokens
- std::unordered_map<llama_token, int> token_count;
- for (size_t i = 0; i < penalty_last_n; ++i) {
- token_count[last_tokens[i]]++;
- }
+static void llama_sampler_grammar_impl(llama_token_data_array * cur_p, const struct llama_grammar & grammar) {
+ llama_grammar_apply_impl(grammar, cur_p);
+}
- // Apply frequency and presence penalties to the candidates
- for (size_t i = 0; i < candidates->size; ++i) {
- const auto token_iter = token_count.find(candidates->data[i].id);
+void llama_sampler_penalties_impl(
+ llama_token_data_array * cur_p,
+ const llama_token_cnt & token_count,
+ float penalty_repeat,
+ float penalty_freq,
+ float penalty_present) {
+    // Apply frequency and presence penalties to the candidates in cur_p
+ for (size_t i = 0; i < cur_p->size; ++i) {
+ const auto token_iter = token_count.find(cur_p->data[i].id);
if (token_iter == token_count.end()) {
continue;
}
        // The academic publication that described this technique only divided, but that would cause tokens with negative logits to become more likely, which is obviously wrong.
        // A common fix for this problem is to multiply by the penalty instead of dividing.
- if (candidates->data[i].logit <= 0) {
- candidates->data[i].logit *= penalty_repeat;
+ if (cur_p->data[i].logit <= 0) {
+ cur_p->data[i].logit *= penalty_repeat;
} else {
- candidates->data[i].logit /= penalty_repeat;
+ cur_p->data[i].logit /= penalty_repeat;
}
- candidates->data[i].logit -= float(count) * penalty_freq + float(count > 0) * penalty_present;
+ cur_p->data[i].logit -= float(count) * penalty_freq + float(count > 0) * penalty_present;
+ }
+
+ cur_p->sorted = false;
+}
+
+// llama_sampler API
+
+const char * llama_sampler_name(const struct llama_sampler * smpl) {
+ if (!smpl->iface) {
+ return "(null)";
+ }
+
+ return smpl->iface->name(smpl);
+}
+
+void llama_sampler_accept(struct llama_sampler * smpl, llama_token token) {
+ if (smpl->iface->accept) {
+ smpl->iface->accept(smpl, token);
+ }
+}
+
+void llama_sampler_apply(struct llama_sampler * smpl, struct llama_token_data_array * cur_p) {
+ GGML_ASSERT(smpl->iface->apply);
+ smpl->iface->apply(smpl, cur_p);
+}
+
+void llama_sampler_reset(struct llama_sampler * smpl) {
+ if (smpl->iface->reset) {
+ smpl->iface->reset(smpl);
}
+}
- candidates->sorted = false;
+struct llama_sampler * llama_sampler_clone(const struct llama_sampler * smpl) {
+ if (smpl->iface->clone) {
+ return smpl->iface->clone(smpl);
+ }
- if (smpl) {
- smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
+ if (smpl->ctx == nullptr) {
+ return new llama_sampler {
+ /* .iface = */ smpl->iface,
+ /* .ctx = */ nullptr,
+ };
}
+
+ GGML_ABORT("the sampler does not support cloning");
}
-void llama_sample_apply_guidance_impl(
- struct llama_sampling * smpl,
- float * logits,
- float * logits_guidance,
- float scale) {
- GGML_ASSERT(smpl);
+void llama_sampler_free(struct llama_sampler * smpl) {
+ if (smpl == nullptr) {
+ return;
+ }
+
+ if (smpl->iface->free) {
+ smpl->iface->free(smpl);
+ }
- const auto t_start_sample_us = ggml_time_us();
- const auto n_vocab = smpl->n_vocab;
+ delete smpl;
+}
- llama_log_softmax(logits, n_vocab);
- llama_log_softmax(logits_guidance, n_vocab);
+llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_context * ctx, int32_t idx) {
+ const auto * logits = llama_get_logits_ith(ctx, idx);
- for (int i = 0; i < n_vocab; ++i) {
- auto & l = logits[i];
- const auto & g = logits_guidance[i];
+ const int n_vocab = llama_n_vocab(llama_get_model(ctx));
- l = scale * (l - g) + g;
+ // TODO: do not allocate each time
+ std::vector<llama_token_data> cur(n_vocab);
+ for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
+ cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f};
}
- smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
+ llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false };
+
+ llama_sampler_apply(smpl, &cur_p);
+
+ return cur_p.data[cur_p.selected].id;
}
-llama_token llama_sample_token_mirostat_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float tau, float eta, int32_t m, float * mu) {
- GGML_ASSERT(smpl);
+// sampler chain
+
+static struct llama_sampler_i llama_sampler_chain_i = {
+ /* .name = */ [](const struct llama_sampler * /*smpl*/) { return "chain"; },
+ /* .accept = */ [](struct llama_sampler * smpl, llama_token token) {
+ auto * chain = (llama_sampler_chain *) smpl->ctx;
+
+ time_meas tm(chain->t_sample_us, chain->params.no_perf);
+
+ for (auto * smpl : chain->samplers) {
+ llama_sampler_accept(smpl, token);
+ }
+
+ chain->n_sample++;
+ },
+ /* .apply = */ [](struct llama_sampler * smpl, llama_token_data_array * cur_p) {
+ auto * chain = (llama_sampler_chain *) smpl->ctx;
+
+ time_meas tm(chain->t_sample_us, chain->params.no_perf);
+
+ for (auto * smpl : chain->samplers) {
+ llama_sampler_apply(smpl, cur_p);
+ }
+ },
+ /* .reset = */ [](struct llama_sampler * smpl) {
+ auto * chain = (llama_sampler_chain *) smpl->ctx;
+
+ for (auto * smpl : chain->samplers) {
+ llama_sampler_reset(smpl);
+ }
+
+ chain->t_sample_us = 0;
+ chain->n_sample = 0;
+ },
+ /* .clone = */ [](const struct llama_sampler * smpl) {
+ const auto * chain_src = (const llama_sampler_chain *) smpl->ctx;
- const int32_t n_vocab = float(smpl->n_vocab);
+ auto * result = llama_sampler_chain_init(chain_src->params);
- int64_t t_start_sample_us = ggml_time_us();
+ for (auto * smpl : chain_src->samplers) {
+ llama_sampler_chain_add(result, llama_sampler_clone(smpl));
+ }
+
+ return result;
+ },
+ /* .free = */ [](struct llama_sampler * smpl) {
+ auto * chain = (llama_sampler_chain *) smpl->ctx;
+
+ for (auto * smpl : chain->samplers) {
+ llama_sampler_free(smpl);
+ }
- llama_sample_softmax_impl((struct llama_sampling *) nullptr, candidates);
+ delete chain;
+ },
+};
+
+struct llama_sampler * llama_sampler_chain_init(struct llama_sampler_chain_params params) {
+ return new llama_sampler {
+ /* .iface = */ &llama_sampler_chain_i,
+ /* .ctx = */ new llama_sampler_chain {
+ /* .params = */ params,
+ /* .samplers = */ {},
+ /* .t_sample_us = */ 0,
+ /* .n_sample = */ 0,
+ },
+ };
+}
+
+void llama_sampler_chain_add(struct llama_sampler * chain, struct llama_sampler * smpl) {
+ auto * p = (llama_sampler_chain *) chain->ctx;
+ p->samplers.push_back(smpl);
+}
- // Estimate s_hat using the most probable m tokens
- float s_hat = 0.0;
- float sum_ti_bi = 0.0;
- float sum_ti_sq = 0.0;
- for (size_t i = 0; i < size_t(m - 1) && i < candidates->size - 1; ++i) {
- float t_i = logf(float(i + 2) / float(i + 1));
- float b_i = logf(candidates->data[i].p / candidates->data[i + 1].p);
- sum_ti_bi += t_i * b_i;
- sum_ti_sq += t_i * t_i;
+struct llama_sampler * llama_sampler_chain_get(const struct llama_sampler * chain, int32_t i) {
+ const auto * p = (const llama_sampler_chain *) chain->ctx;
+
+ if (i < 0 || i >= (int32_t) p->samplers.size()) {
+ return nullptr;
}
- s_hat = sum_ti_bi / sum_ti_sq;
- // Compute k from the estimated s_hat and target surprise value
- float epsilon_hat = s_hat - 1;
- float k = powf((epsilon_hat * powf(2, *mu)) / (1 - powf(n_vocab, -epsilon_hat)), 1 / s_hat);
+ return p->samplers[i];
+}
+
+int llama_sampler_chain_n(const struct llama_sampler * chain) {
+ const auto * p = (const llama_sampler_chain *) chain->ctx;
- // Sample the next word X using top-k sampling
- llama_sample_top_k_impl((struct llama_sampling *) nullptr, candidates, int(k), 1);
- smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
- llama_token X = llama_sample_token_impl(smpl, candidates);
- t_start_sample_us = ggml_time_us();
+ return p->samplers.size();
+}
- // Compute error as the difference between observed surprise and target surprise value
- size_t X_idx = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) {
- return candidate.id == X;
- }));
- float observed_surprise = -log2f(candidates->data[X_idx].p);
- float e = observed_surprise - tau;
+//
+// samplers
+//
- // Update mu using the learning rate and error
- *mu = *mu - eta * e;
+// greedy
- smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
- return X;
+static struct llama_sampler_i llama_sampler_greedy_i = {
+ /* .name = */ [](const struct llama_sampler * /*smpl*/) { return "greedy"; },
+ /* .accept = */ nullptr,
+ /* .apply = */ [](struct llama_sampler * /*smpl*/, llama_token_data_array * cur_p) {
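+ // argmax over the logits; no probabilities or RNG needed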
+ cur_p->selected = 0;
+ for (size_t i = 1; i < cur_p->size; ++i) {
+ if (cur_p->data[i].logit > cur_p->data[cur_p->selected].logit) {
+ cur_p->selected = i;
+ }
+ }
+ },
+ /* .reset = */ nullptr,
+ /* .clone = */ nullptr,
+ /* .free = */ nullptr,
+};
+
+struct llama_sampler * llama_sampler_init_greedy() {
+ return new llama_sampler {
+ /* .iface = */ &llama_sampler_greedy_i,
+ /* .ctx = */ nullptr,
+ };
}
-llama_token llama_sample_token_mirostat_v2_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float tau, float eta, float * mu) {
- int64_t t_start_sample_us;
- t_start_sample_us = ggml_time_us();
+// dist
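+// draws the token index from the candidate distribution using the sampler's own seeded RNG,
+// so generation is reproducible for a fixed seed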
- llama_sample_softmax_impl(smpl, candidates);
+struct llama_sampler_dist {
+ const uint32_t seed;
- // Truncate the words with surprise values greater than mu
- candidates->size = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) {
- return -log2f(candidate.p) > *mu;
- }));
+ std::mt19937 rng;
- if (candidates->size == 0) {
- candidates->size = 1;
- }
+ std::vector<float> probs; // work array
+};
- if (smpl) {
- smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
- }
+static struct llama_sampler_i llama_sampler_dist_i = {
+ /* .name = */ [](const struct llama_sampler * /*smpl*/) { return "dist"; },
+ /* .accept = */ nullptr,
+ /* .apply = */ [](struct llama_sampler * smpl, llama_token_data_array * cur_p) {
+ auto * ctx = (llama_sampler_dist *) smpl->ctx;
+ cur_p->selected = llama_sample_dist(cur_p, ctx->rng, ctx->probs);
+ },
+ /* .reset = */ nullptr,
+ /* .clone = */ [](const struct llama_sampler * smpl) {
+ const auto * ctx = (const llama_sampler_dist *) smpl->ctx;
+ auto * result = llama_sampler_init_dist(ctx->seed);
- // Normalize the probabilities of the remaining words
- llama_sample_softmax_impl(smpl, candidates);
+ // copy the state
+ {
+ auto * result_ctx = (llama_sampler_dist *) result->ctx;
- // Sample the next word X from the remaining words
- llama_token X = llama_sample_token_impl(smpl, candidates);
- t_start_sample_us = ggml_time_us();
+ result_ctx->rng = ctx->rng;
+ }
- // Compute error as the difference between observed surprise and target surprise value
- size_t X_idx = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) {
- return candidate.id == X;
- }));
- float observed_surprise = -log2f(candidates->data[X_idx].p);
- float e = observed_surprise - tau;
+ return result;
+ },
+ /* .free = */ [](struct llama_sampler * smpl) {
+ delete (llama_sampler_dist *) smpl->ctx;
+ },
+};
+
+struct llama_sampler * llama_sampler_init_dist(uint32_t seed) {
+ return new llama_sampler {
+ /* .iface = */ &llama_sampler_dist_i,
+ /* .ctx = */ new llama_sampler_dist {
+ /* .seed = */ seed,
+ /* .rng = */ std::mt19937(seed),
+ /* .probs = */ {},
+ },
+ };
+}
- // Update mu using the learning rate and error
- *mu = *mu - eta * e;
+// softmax
+
+static struct llama_sampler_i llama_sampler_softmax_i = {
+ /* .name = */ [](const struct llama_sampler * /*smpl*/) { return "softmax"; },
+ /* .accept = */ nullptr,
+ /* .apply = */ [](struct llama_sampler * /*smpl*/, llama_token_data_array * cur_p) {
+ llama_sampler_softmax_impl(cur_p);
+ },
+ /* .reset = */ nullptr,
+ /* .clone = */ nullptr,
+ /* .free = */ nullptr,
+};
+
+struct llama_sampler * llama_sampler_init_softmax() {
+ return new llama_sampler {
+ /* .iface = */ &llama_sampler_softmax_i,
+ /* .ctx = */ nullptr,
+ };
+}
- if (smpl) {
- smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
- }
- return X;
+// top-k
+
+struct llama_sampler_top_k {
+ const int32_t k;
+};
+
+static struct llama_sampler_i llama_sampler_top_k_i = {
+ /* .name = */ [](const struct llama_sampler * /*smpl*/) { return "top-k"; },
+ /* .accept = */ nullptr,
+ /* .apply = */ [](struct llama_sampler * smpl, llama_token_data_array * cur_p) {
+ const auto * ctx = (llama_sampler_top_k *) smpl->ctx;
+ llama_sampler_top_k_impl(cur_p, ctx->k);
+ },
+ /* .reset = */ nullptr,
+ /* .clone = */ [](const struct llama_sampler * smpl) {
+ const auto * ctx = (const llama_sampler_top_k *) smpl->ctx;
+ return llama_sampler_init_top_k(ctx->k);
+ },
+ /* .free = */ [](struct llama_sampler * smpl) {
+ delete (llama_sampler_top_k *) smpl->ctx;
+ },
+};
+
+struct llama_sampler * llama_sampler_init_top_k(int32_t k) {
+ return new llama_sampler {
+ /* .iface = */ &llama_sampler_top_k_i,
+ /* .ctx = */ new llama_sampler_top_k {
+ /* .k = */ k,
+ },
+ };
}
-llama_token llama_sample_token_greedy_impl(struct llama_sampling * smpl, llama_token_data_array * candidates) {
- const int64_t t_start_sample_us = ggml_time_us();
+// top-p
+
+struct llama_sampler_top_p {
+ const float p;
+ const size_t min_keep;
+};
+
+static struct llama_sampler_i llama_sampler_top_p_i = {
+ /* .name = */ [](const struct llama_sampler * /*smpl*/) { return "top-p"; },
+ /* .accept = */ nullptr,
+ /* .apply = */ [](struct llama_sampler * smpl, llama_token_data_array * cur_p) {
+ const auto * ctx = (llama_sampler_top_p *) smpl->ctx;
+ llama_sampler_top_p_impl(cur_p, ctx->p, ctx->min_keep);
+ },
+ /* .reset = */ nullptr,
+ /* .clone = */ [](const struct llama_sampler * smpl) {
+ const auto * ctx = (const llama_sampler_top_p *) smpl->ctx;
+ return llama_sampler_init_top_p(ctx->p, ctx->min_keep);
+ },
+ /* .free = */ [](struct llama_sampler * smpl) {
+ delete (llama_sampler_top_p *) smpl->ctx;
+ },
+};
+
+struct llama_sampler * llama_sampler_init_top_p(float p, size_t min_keep) {
+ return new llama_sampler {
+ /* .iface = */ &llama_sampler_top_p_i,
+ /* .ctx = */ new llama_sampler_top_p {
+ /* .p = */ p,
+ /* .min_keep = */ min_keep,
+ },
+ };
+}
- // Find max element
- auto * max_iter = std::max_element(candidates->data, candidates->data + candidates->size, [](const llama_token_data & a, const llama_token_data & b) {
- return a.logit < b.logit;
- });
+// min-p
+
+struct llama_sampler_min_p {
+ const float p;
+ const size_t min_keep;
+};
+
+static struct llama_sampler_i llama_sampler_min_p_i = {
+ /* .name = */ [](const struct llama_sampler * /*smpl*/) { return "min-p"; },
+ /* .accept = */ nullptr,
+ /* .apply = */ [](struct llama_sampler * smpl, llama_token_data_array * cur_p) {
+ const auto * ctx = (llama_sampler_min_p *) smpl->ctx;
+ llama_sampler_min_p_impl(cur_p, ctx->p, ctx->min_keep);
+ },
+ /* .reset = */ nullptr,
+ /* .clone = */ [](const struct llama_sampler * smpl) {
+ const auto * ctx = (const llama_sampler_min_p *) smpl->ctx;
+ return llama_sampler_init_min_p(ctx->p, ctx->min_keep);
+ },
+ /* .free = */ [](struct llama_sampler * smpl) {
+ delete (llama_sampler_min_p *) smpl->ctx;
+ },
+};
+
+struct llama_sampler * llama_sampler_init_min_p(float p, size_t min_keep) {
+ return new llama_sampler {
+ /* .iface = */ &llama_sampler_min_p_i,
+ /* .ctx = */ new llama_sampler_min_p {
+ /* .p = */ p,
+ /* .min_keep = */ min_keep,
+ },
+ };
+}
- llama_token result = max_iter->id;
- if (smpl) {
- smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
- smpl->n_sample++;
- }
- return result;
+// tail-free
+
+struct llama_sampler_tail_free {
+ const float z;
+ const size_t min_keep;
+};
+
+static struct llama_sampler_i llama_sampler_tail_free_i = {
+ /* .name = */ [](const struct llama_sampler * /*smpl*/) { return "tail-free"; },
+ /* .accept = */ nullptr,
+ /* .apply = */ [](struct llama_sampler * smpl, llama_token_data_array * cur_p) {
+ const auto * ctx = (llama_sampler_tail_free *) smpl->ctx;
+ llama_sampler_tail_free_impl(cur_p, ctx->z, ctx->min_keep);
+ },
+ /* .reset = */ nullptr,
+ /* .clone = */ [](const struct llama_sampler * smpl) {
+ const auto * ctx = (const llama_sampler_tail_free *) smpl->ctx;
+ return llama_sampler_init_tail_free(ctx->z, ctx->min_keep);
+ },
+ /* .free = */ [](struct llama_sampler * smpl) {
+ delete (llama_sampler_tail_free *) smpl->ctx;
+ },
+};
+
+struct llama_sampler * llama_sampler_init_tail_free(float z, size_t min_keep) {
+ return new llama_sampler {
+ /* .iface = */ &llama_sampler_tail_free_i,
+ /* .ctx = */ new llama_sampler_tail_free {
+ /* .z = */ z,
+ /* .min_keep = */ min_keep,
+ },
+ };
}
-llama_token llama_sample_token_with_rng_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, std::mt19937 & rng) {
- GGML_ASSERT(smpl);
+// typical
+
+struct llama_sampler_typical {
+ const float p;
+ const size_t min_keep;
+};
+
+static struct llama_sampler_i llama_sampler_typical_i = {
+ /* .name = */ [](const struct llama_sampler * /*smpl*/) { return "typical"; },
+ /* .accept = */ nullptr,
+ /* .apply = */ [](struct llama_sampler * smpl, llama_token_data_array * cur_p) {
+ const auto * ctx = (llama_sampler_typical *) smpl->ctx;
+ llama_sampler_typical_impl(cur_p, ctx->p, ctx->min_keep);
+ },
+ /* .reset = */ nullptr,
+ /* .clone = */ [](const struct llama_sampler * smpl) {
+ const auto * ctx = (const llama_sampler_typical *) smpl->ctx;
+ return llama_sampler_init_typical(ctx->p, ctx->min_keep);
+ },
+ /* .free = */ [](struct llama_sampler * smpl) {
+ delete (llama_sampler_typical *) smpl->ctx;
+ },
+};
+
+struct llama_sampler * llama_sampler_init_typical(float p, size_t min_keep) {
+ return new llama_sampler {
+ /* .iface = */ &llama_sampler_typical_i,
+ /* .ctx = */ new llama_sampler_typical {
+ /* .p = */ p,
+ /* .min_keep = */ min_keep,
+ },
+ };
+}
+
+// temp
+
+struct llama_sampler_temp {
+ const float temp;
+};
+
+static struct llama_sampler_i llama_sampler_temp_i = {
+ /* .name = */ [](const struct llama_sampler * /*smpl*/) { return "temp"; },
+ /* .accept = */ nullptr,
+ /* .apply = */ [](struct llama_sampler * smpl, llama_token_data_array * cur_p) {
+ const auto * ctx = (llama_sampler_temp *) smpl->ctx;
+ llama_sampler_temp_impl(cur_p, ctx->temp);
+ },
+ /* .reset = */ nullptr,
+ /* .clone = */ [](const struct llama_sampler * smpl) {
+ const auto * ctx = (const llama_sampler_temp *) smpl->ctx;
+ return llama_sampler_init_temp(ctx->temp);
+ },
+ /* .free = */ [](struct llama_sampler * smpl) {
+ delete (llama_sampler_temp *) smpl->ctx;
+ },
+};
+
+struct llama_sampler * llama_sampler_init_temp(float temp) {
+ return new llama_sampler {
+ /* .iface = */ &llama_sampler_temp_i,
+ /* .ctx = */ new llama_sampler_temp {
+ /* .temp = */ temp,
+ },
+ };
+}
- const int64_t t_start_sample_us = ggml_time_us();
- llama_sample_softmax_impl((struct llama_sampling *) nullptr, candidates);
+// temp-ext
+
+struct llama_sampler_temp_ext {
+ const float temp;
+ const float delta;
+ const float exponent;
+};
+
+static struct llama_sampler_i llama_sampler_temp_ext_i = {
+ /* .name = */ [](const struct llama_sampler * /*smpl*/) { return "temp-ext"; },
+ /* .accept = */ nullptr,
+ /* .apply = */ [](struct llama_sampler * smpl, llama_token_data_array * cur_p) {
+ const auto * ctx = (llama_sampler_temp_ext *) smpl->ctx;
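+ // delta > 0 enables dynamic (entropy-based) temperature in [temp - delta, temp + delta];
+ // otherwise apply plain temperature scaling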
+ if (ctx->delta > 0) {
+ const float temp_min = std::max(0.0f, ctx->temp - ctx->delta);
+ const float temp_max = ctx->temp + ctx->delta;
+
+ llama_sampler_entropy_impl(cur_p, temp_min, temp_max, ctx->exponent);
+ } else {
+ llama_sampler_temp_impl(cur_p, ctx->temp);
+ }
+ },
+ /* .reset = */ nullptr,
+ /* .clone = */ [](const struct llama_sampler * smpl) {
+ const auto * ctx = (const llama_sampler_temp_ext *) smpl->ctx;
+ return llama_sampler_init_temp_ext(ctx->temp, ctx->delta, ctx->exponent);
+ },
+ /* .free = */ [](struct llama_sampler * smpl) {
+ delete (llama_sampler_temp_ext *) smpl->ctx;
+ },
+};
+
+struct llama_sampler * llama_sampler_init_temp_ext(float temp, float delta, float exponent) {
+ return new llama_sampler {
+ /* .iface = */ &llama_sampler_temp_ext_i,
+ /* .ctx = */ new llama_sampler_temp_ext {
+ /* .temp = */ temp,
+ /* .delta = */ delta,
+ /* .exponent = */ exponent,
+ },
+ };
+}
+
+// mirostat
+
+struct llama_sampler_mirostat {
+ const int32_t n_vocab;
+
+ const uint32_t seed;
+
+ const float tau;
+ const float eta;
+
+ const int32_t m;
+
+ float mu;
+
+ std::mt19937 rng;
+
+ std::vector<float> probs;
+};
+
+static struct llama_sampler_i llama_sampler_mirostat_i = {
+ /* .name = */ [](const struct llama_sampler * /*smpl*/) { return "mirostat"; },
+ /* .accept = */ nullptr,
+ /* .apply = */ [](struct llama_sampler * smpl, llama_token_data_array * cur_p) {
+ auto * ctx = (llama_sampler_mirostat *) smpl->ctx;
+
+ llama_sampler_softmax_impl(cur_p);
+
+ // Estimate s_hat using the most probable m tokens
+ float s_hat = 0.0;
+ float sum_ti_bi = 0.0;
+ float sum_ti_sq = 0.0;
+ for (size_t i = 0; i < size_t(ctx->m - 1) && i < cur_p->size - 1; ++i) {
+ float t_i = logf(float(i + 2) / float(i + 1));
+ float b_i = logf(cur_p->data[i].p / cur_p->data[i + 1].p);
+ sum_ti_bi += t_i * b_i;
+ sum_ti_sq += t_i * t_i;
+ }
+ s_hat = sum_ti_bi / sum_ti_sq;
+
+ // Compute k from the estimated s_hat and target surprise value
+ float epsilon_hat = s_hat - 1;
+ float k = powf((epsilon_hat * powf(2, ctx->mu)) / (1 - powf(ctx->n_vocab, -epsilon_hat)), 1 / s_hat);
+
+ llama_sampler_top_k_impl(cur_p, std::max(int(k), 1));
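+ // re-normalize the remaining candidates and sample the next token from them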
+ llama_sampler_softmax_impl(cur_p);
+
+ const int idx = llama_sample_dist(cur_p, ctx->rng, ctx->probs);
+
+ cur_p->selected = idx;
+
+ float observed_surprise = -log2f(cur_p->data[idx].p);
+ float e = observed_surprise - ctx->tau;
+
+ // Update mu using the learning rate and error
+ ctx->mu = ctx->mu - ctx->eta * e;
+ },
+ /* .reset = */ [](struct llama_sampler * smpl) {
+ auto * ctx = (llama_sampler_mirostat *) smpl->ctx;
+ ctx->mu = 2.0f*ctx->tau;
+ ctx->rng = std::mt19937(ctx->seed);
+ },
+ /* .clone = */ [](const struct llama_sampler * smpl) {
+ const auto * ctx = (const llama_sampler_mirostat *) smpl->ctx;
+ auto * result = llama_sampler_init_mirostat(ctx->n_vocab, ctx->seed, ctx->tau, ctx->eta, ctx->m);
+
+ // copy the state
+ {
+ auto * result_ctx = (llama_sampler_mirostat *) result->ctx;
+
+ result_ctx->mu = ctx->mu;
+ result_ctx->rng = ctx->rng;
+ }
+
+ return result;
+ },
+ /* .free = */ [](struct llama_sampler * smpl) {
+ delete (llama_sampler_mirostat *) smpl->ctx;
+ },
+};
+
+struct llama_sampler * llama_sampler_init_mirostat(int32_t n_vocab, uint32_t seed, float tau, float eta, int32_t m) {
+ return new llama_sampler {
+ /* .iface = */ &llama_sampler_mirostat_i,
+ /* .ctx = */ new llama_sampler_mirostat {
+ /* .n_vocab = */ n_vocab,
+ /* .seed = */ seed,
+ /* .tau = */ tau,
+ /* .eta = */ eta,
+ /* .m = */ m,
+ /* .mu = */ 2.0f*tau,
+ /* .rng = */ std::mt19937(seed),
+ /* .probs = */ {},
+ },
+ };
+}
+
+// mirostat v2
+
+struct llama_sampler_mirostat_v2 {
+ const uint32_t seed;
+
+ const float tau;
+ const float eta;
+
+ float mu;
+
+ std::mt19937 rng;
std::vector<float> probs;
- probs.reserve(candidates->size);
- for (size_t i = 0; i < candidates->size; ++i) {
- probs.push_back(candidates->data[i].p);
+};
+
+static struct llama_sampler_i llama_sampler_mirostat_v2_i = {
+ /* .name = */ [](const struct llama_sampler * /*smpl*/) { return "mirostat-v2"; },
+ /* .accept = */ nullptr,
+ /* .apply = */ [](struct llama_sampler * smpl, llama_token_data_array * cur_p) {
+ auto * ctx = (llama_sampler_mirostat_v2 *) smpl->ctx;
+
+ llama_sampler_softmax_impl(cur_p);
+
+ // Truncate the words with surprise values greater than mu
+ cur_p->size = std::distance(cur_p->data, std::find_if(cur_p->data, cur_p->data + cur_p->size, [&](const llama_token_data & candidate) {
+ return -log2f(candidate.p) > ctx->mu;
+ }));
+
+ if (cur_p->size == 0) {
+ cur_p->size = 1;
+ }
+
+ // Normalize the probabilities of the remaining words
+ llama_sampler_softmax_impl(cur_p);
+
+ const int idx = llama_sample_dist(cur_p, ctx->rng, ctx->probs);
+
+ cur_p->selected = idx;
+
+ float observed_surprise = -log2f(cur_p->data[idx].p);
+ float e = observed_surprise - ctx->tau;
+
+ // Update mu using the learning rate and error
+ ctx->mu = ctx->mu - ctx->eta * e;
+ },
+ /* .reset = */ [](struct llama_sampler * smpl) {
+ auto * ctx = (llama_sampler_mirostat_v2 *) smpl->ctx;
+ ctx->mu = 2.0f*ctx->tau;
+ ctx->rng = std::mt19937(ctx->seed);
+ },
+ /* .clone = */ [](const struct llama_sampler * smpl) {
+ const auto * ctx = (const llama_sampler_mirostat_v2 *) smpl->ctx;
+
+ auto * result = llama_sampler_init_mirostat_v2(ctx->seed, ctx->tau, ctx->eta);
+
+ // copy the state
+ {
+ auto * result_ctx = (llama_sampler_mirostat_v2 *) result->ctx;
+
+ result_ctx->mu = ctx->mu;
+ result_ctx->rng = ctx->rng;
+ }
+
+ return result;
+ },
+ /* .free = */ [](struct llama_sampler * smpl) {
+ delete (llama_sampler_mirostat_v2 *) smpl->ctx;
+ },
+};
+
+struct llama_sampler * llama_sampler_init_mirostat_v2(uint32_t seed, float tau, float eta) {
+ return new llama_sampler {
+ /* .iface = */ &llama_sampler_mirostat_v2_i,
+ /* .ctx = */ new llama_sampler_mirostat_v2 {
+ /* .seed = */ seed,
+ /* .tau = */ tau,
+ /* .eta = */ eta,
+ /* .mu = */ 2.0f*tau,
+ /* .rng = */ std::mt19937(seed),
+ /* .probs = */ {},
+ },
+ };
+}
+
+// grammar
+
+struct llama_sampler_grammar {
+ const struct llama_vocab * vocab;
+
+ std::string grammar_str;
+ std::string grammar_root;
+
+ struct llama_grammar * grammar;
+};
+
+static struct llama_sampler_i llama_sampler_grammar_i = {
+ /* .name = */ [](const struct llama_sampler * /*smpl*/) { return "grammar"; },
+ /* .accept = */ [](struct llama_sampler * smpl, llama_token token) {
+ const auto * ctx = (llama_sampler_grammar *) smpl->ctx;
+ if (ctx->grammar) {
+ llama_grammar_accept_impl(*ctx->grammar, token);
+ }
+ },
+ /* .apply = */ [](struct llama_sampler * smpl, llama_token_data_array * cur_p) {
+ const auto * ctx = (llama_sampler_grammar *) smpl->ctx;
+ if (ctx->grammar) {
+ llama_sampler_grammar_impl(cur_p, *ctx->grammar);
+ }
+ },
+ /* .reset = */ [](struct llama_sampler * smpl) {
+ auto * ctx = (llama_sampler_grammar *) smpl->ctx;
+ if (!ctx->grammar) {
+ return;
+ }
+
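+ // re-initialize the grammar from the stored string so matching restarts from the root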
+ auto * grammar_new = llama_grammar_init_impl(ctx->grammar->vocab, ctx->grammar_str.c_str(), ctx->grammar_root.c_str());
+
+ llama_grammar_free_impl(ctx->grammar);
+ ctx->grammar = grammar_new;
+ },
+ /* .clone = */ [](const struct llama_sampler * smpl) {
+ const auto * ctx = (const llama_sampler_grammar *) smpl->ctx;
+
+ auto * result = llama_sampler_init_grammar_impl(*ctx->vocab, nullptr, nullptr);
+
+ // copy the state
+ {
+ auto * result_ctx = (llama_sampler_grammar *) result->ctx;
+
+ if (ctx->grammar) {
+ result_ctx->grammar_str = ctx->grammar_str;
+ result_ctx->grammar_root = ctx->grammar_root;
+
+ result_ctx->grammar = llama_grammar_clone_impl(*ctx->grammar);
+ }
+ }
+
+ return result;
+ },
+ /* .free = */ [](struct llama_sampler * smpl) {
+ const auto * ctx = (llama_sampler_grammar *) smpl->ctx;
+
+ if (ctx->grammar) {
+ llama_grammar_free_impl(ctx->grammar);
+ }
+
+ delete ctx;
+ },
+};
+
+struct llama_sampler * llama_sampler_init_grammar_impl(const struct llama_vocab & vocab, const char * grammar_str, const char * grammar_root) {
+ auto * ctx = new llama_sampler_grammar;
+
+ if (grammar_str != nullptr && grammar_str[0] != '\0') {
+ *ctx = {
+ /* .vocab = */ &vocab,
+ /* .grammar_str = */ grammar_str,
+ /* .grammar_root = */ grammar_root,
+ /* .grammar = */ llama_grammar_init_impl(&vocab, grammar_str, grammar_root),
+ };
+ } else {
+ *ctx = {
+ /* .vocab = */ &vocab,
+ /* .grammar_str = */ {},
+ /* .grammar_root = */ {},
+ /* .grammar = */ nullptr,
+ };
}
- std::discrete_distribution<> dist(probs.begin(), probs.end());
- int idx = dist(rng);
+ return new llama_sampler {
+ /* .iface = */ &llama_sampler_grammar_i,
+ /* .ctx = */ ctx,
+ };
+}
- llama_token result = candidates->data[idx].id;
+// penalties
+
+struct llama_sampler_penalties {
+ const int32_t n_vocab;
+ const llama_token special_eos_id;
+ const llama_token linefeed_id;
+
+ const int32_t penalty_last_n;
+ const float penalty_repeat;
+ const float penalty_freq;
+ const float penalty_present;
+
+ const bool penalize_nl;
+ const bool ignore_eos;
+
+ ring_buffer<llama_token> prev;
+};
+
+static struct llama_sampler_i llama_sampler_penalties_i = {
+ /* .name = */ [](const struct llama_sampler * /*smpl*/) { return "penalties"; },
+ /* .accept = */ [](struct llama_sampler * smpl, llama_token token) {
+ auto * ctx = (llama_sampler_penalties *) smpl->ctx;
+ ctx->prev.push_back(token);
+ },
+ /* .apply = */ [](struct llama_sampler * smpl, llama_token_data_array * cur_p) {
+ auto * ctx = (llama_sampler_penalties *) smpl->ctx;
+
+ if (ctx->ignore_eos) {
+ assert(ctx->special_eos_id >= 0);
+
+ // optimistically check if the candidates are not yet sorted/shuffled/truncated
+ if (cur_p->size > (size_t) ctx->special_eos_id && cur_p->data[ctx->special_eos_id].id == ctx->special_eos_id) {
+ cur_p->data[ctx->special_eos_id].logit = -INFINITY;
+ } else {
+ // else, search for the special EOS token
+ for (size_t i = 0; i < cur_p->size; ++i) {
+ if (cur_p->data[i].id == ctx->special_eos_id) {
+ cur_p->data[i].logit = -INFINITY;
+ break;
+ }
+ }
+ }
+ }
- smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
- smpl->n_sample++;
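+ // nothing more to do if all penalties are disabled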
+ if ((ctx->penalty_last_n == 0) ||
+ (ctx->penalty_repeat == 1.0f && ctx->penalty_freq == 0.0f && ctx->penalty_present == 0.0f)) {
+ return;
+ }
- return result;
+ bool nl_found = false;
+ size_t nl_idx = 0;
+ float nl_logit = -INFINITY;
+ if (!ctx->penalize_nl) {
+ assert(ctx->linefeed_id >= 0);
+
+ // optimistically check if the candidates are not yet sorted/shuffled/truncated
+ if (cur_p->size > (size_t) ctx->linefeed_id && cur_p->data[ctx->linefeed_id].id == ctx->linefeed_id) {
+ nl_found = true;
+ nl_idx = ctx->linefeed_id;
+ nl_logit = cur_p->data[ctx->linefeed_id].logit;
+ } else {
+ // else, search for the linefeed token
+ for (size_t i = 0; i < cur_p->size; ++i) {
+ if (cur_p->data[i].id == ctx->linefeed_id) {
+ nl_found = true;
+ nl_idx = i;
+ nl_logit = cur_p->data[i].logit;
+ break;
+ }
+ }
+ }
+ }
+
+ // Create a frequency map to count occurrences of each token in the recent tokens (ctx->prev)
+ // TODO: optimize this by maintaining the token count in the sampler context
+ llama_token_cnt token_count;
+ for (int i = 0; i < std::min<int>(ctx->penalty_last_n, ctx->prev.size()); ++i) {
+ token_count[ctx->prev.rat(i)]++;
+ }
+
+ llama_sampler_penalties_impl(cur_p, token_count, ctx->penalty_repeat, ctx->penalty_freq, ctx->penalty_present);
+
+ if (!ctx->penalize_nl && nl_found) {
+ // restore the logit of the newline token if it was penalized
+ cur_p->data[nl_idx].logit = nl_logit;
+ }
+ },
+ /* .reset = */ [](struct llama_sampler * smpl) {
+ auto * ctx = (llama_sampler_penalties *) smpl->ctx;
+ ctx->prev.clear();
+ },
+ /* .clone = */ [](const struct llama_sampler * smpl) {
+ const auto * ctx = (const llama_sampler_penalties *) smpl->ctx;
+ auto * result = llama_sampler_init_penalties(
+ ctx->n_vocab,
+ ctx->special_eos_id,
+ ctx->linefeed_id,
+ ctx->penalty_last_n,
+ ctx->penalty_repeat,
+ ctx->penalty_freq,
+ ctx->penalty_present,
+ ctx->penalize_nl,
+ ctx->ignore_eos);
+
+ // copy the state
+ {
+ auto * result_ctx = (llama_sampler_penalties *) result->ctx;
+
+ result_ctx->prev = ctx->prev;
+ }
+
+ return result;
+ },
+ /* .free = */ [](struct llama_sampler * smpl) {
+ delete (llama_sampler_penalties *) smpl->ctx;
+ },
+};
+
+struct llama_sampler * llama_sampler_init_penalties(
+ int32_t n_vocab,
+ llama_token special_eos_id,
+ llama_token linefeed_id,
+ int32_t penalty_last_n,
+ float penalty_repeat,
+ float penalty_freq,
+ float penalty_present,
+ bool penalize_nl,
+ bool ignore_eos) {
+ if (linefeed_id == LLAMA_TOKEN_NULL) {
+ penalize_nl = false;
+ }
+
+ if (special_eos_id == LLAMA_TOKEN_NULL) {
+ ignore_eos = true;
+ }
+
+ return new llama_sampler {
+ /* .iface = */ &llama_sampler_penalties_i,
+ /* .ctx = */ new llama_sampler_penalties {
+ /* .n_vocab = */ n_vocab,
+ /* .special_eos_id = */ special_eos_id,
+ /* .linefeed_id = */ linefeed_id,
+ /* .penalty_last_n = */ penalty_last_n,
+ /* .penalty_repeat = */ penalty_repeat,
+ /* .penalty_freq = */ penalty_freq,
+ /* .penalty_present = */ penalty_present,
+ /* .penalize_nl = */ penalize_nl,
+ /* .ignore_eos = */ ignore_eos,
+ /* .prev = */ ring_buffer<llama_token>(penalty_last_n),
+ },
+ };
}
-llama_token llama_sample_token_impl(struct llama_sampling * smpl, llama_token_data_array * candidates) {
- return llama_sample_token_with_rng_impl(smpl, candidates, smpl->rng);
+// logit-bias
+
+struct llama_sampler_logit_bias {
+ const int32_t n_vocab;
+
+ const std::vector<llama_logit_bias> logit_bias;
+
+ std::vector<llama_logit_bias> to_search;
+};
+
+static struct llama_sampler_i llama_sampler_logit_bias_i = {
+ /* .name = */ [](const struct llama_sampler * /*smpl*/) { return "logit-bias"; },
+ /* .accept = */ nullptr,
+ /* .apply = */ [](struct llama_sampler * smpl, llama_token_data_array * cur_p) {
+ auto * ctx = (llama_sampler_logit_bias *) smpl->ctx;
+
+ ctx->to_search.clear();
+
+ // update the candidates that have not been shuffled in the vocabulary (i.e. idx == id)
+ for (const auto & lb : ctx->logit_bias) {
+ if (lb.token >= 0 && cur_p->size > (size_t) lb.token && cur_p->data[lb.token].id == lb.token) {
+ cur_p->data[lb.token].logit += lb.bias;
+ } else {
+ ctx->to_search.push_back(lb);
+ }
+ }
+
+ // search for the remaining candidates that were not found in the previous step
+ for (size_t i = 0; i < cur_p->size; ++i) {
+ for (const auto & lb : ctx->to_search) {
+ if (cur_p->data[i].id == lb.token) {
+ cur_p->data[i].logit += lb.bias;
+ break;
+ }
+ }
+ }
+ },
+ /* .reset = */ nullptr,
+ /* .clone = */ [](const struct llama_sampler * smpl) {
+ const auto * ctx = (const llama_sampler_logit_bias *) smpl->ctx;
+ return llama_sampler_init_logit_bias(ctx->n_vocab, ctx->logit_bias.size(), ctx->logit_bias.data());
+ },
+ /* .free = */ [](struct llama_sampler * smpl) {
+ delete (llama_sampler_logit_bias *) smpl->ctx;
+ },
+};
+
+struct llama_sampler * llama_sampler_init_logit_bias(
+ int32_t n_vocab,
+ int32_t n_logit_bias,
+ const llama_logit_bias * logit_bias) {
+ return new llama_sampler {
+ /* .iface = */ &llama_sampler_logit_bias_i,
+ /* .ctx = */ new llama_sampler_logit_bias {
+ /* .n_vocab = */ n_vocab,
+ /* .logit_bias = */ std::vector<llama_logit_bias>(logit_bias, logit_bias + n_logit_bias),
+ /* .to_search = */ {},
+ },
+ };
}
#pragma once
-#include "llama-impl.h"
+// TODO: rename llama-sampling.h/.cpp to llama-sampler.h/.cpp ?
-struct llama_sampling {
- llama_sampling(int32_t n_vocab) : n_vocab(n_vocab) {}
+#include "llama-grammar.h"
- std::mt19937 rng;
+#include <unordered_map>
- int32_t n_vocab = 0;
+struct llama_vocab;
+struct llama_grammar;
- mutable int64_t t_sample_us = 0;
- mutable int32_t n_sample = 0;
+// sampler chain
- void reset_timings() const {
- t_sample_us = 0;
- n_sample = 0;
- }
+struct llama_sampler_chain {
+ llama_sampler_chain_params params;
+
+ std::vector<struct llama_sampler *> samplers;
+
+ // timing
+
+ mutable int64_t t_sample_us;
+
+ mutable int32_t n_sample;
};
-//
-// internal API
-//
-
-void llama_set_rng_seed_impl(struct llama_sampling * smpl, uint32_t seed);
-
-void llama_sample_softmax_impl (struct llama_sampling * smpl, llama_token_data_array * candidates);
-void llama_sample_top_k_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, int32_t k, size_t min_keep);
-void llama_sample_top_p_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep);
-void llama_sample_min_p_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep);
-void llama_sample_tail_free_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float z, size_t min_keep);
-void llama_sample_typical_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep);
-void llama_sample_entropy_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, float min_temp, float max_temp, float exponent_val);
-void llama_sample_temp_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, float temp);
-
-void llama_sample_repetition_penalties_impl(
- struct llama_sampling * smpl,
- llama_token_data_array * candidates,
- const llama_token * last_tokens,
- size_t penalty_last_n,
+using llama_token_cnt = std::unordered_map<llama_token, int>;
+
+// TODO: tmp exposed until test-sampling is fixed
+void llama_sampler_penalties_impl(
+ llama_token_data_array * cur_p,
+ const llama_token_cnt & token_count,
float penalty_repeat,
float penalty_freq,
float penalty_present);
-void llama_sample_apply_guidance_impl(
- struct llama_sampling * smpl,
- float * logits,
- float * logits_guidance,
- float scale);
-
-llama_token llama_sample_token_mirostat_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, float tau, float eta, int32_t m, float * mu);
-llama_token llama_sample_token_mirostat_v2_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float tau, float eta, float * mu);
-llama_token llama_sample_token_greedy_impl (struct llama_sampling * smpl, llama_token_data_array * candidates);
-llama_token llama_sample_token_with_rng_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, std::mt19937 & rng);
-llama_token llama_sample_token_impl (struct llama_sampling * smpl, llama_token_data_array * candidates);
-
+struct llama_sampler * llama_sampler_init_grammar_impl(
+ const struct llama_vocab & vocab,
+ const char * grammar_str,
+ const char * grammar_root);
tattr attr;
};
+ uint32_t n_vocab = 0; // TODO: not great because it has to be kept in sync with hparams.n_vocab
+
enum llama_vocab_type type = LLAMA_VOCAB_TYPE_SPM;
enum llama_vocab_pre_type type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
int find_bpe_rank(const std::string & token_left, const std::string & token_right) const;
};
-const struct llama_vocab * llama_get_vocab(const struct llama_context * ctx);
-
//
// internal API
//
bool add_special,
bool parse_special = false);
+// TODO: move the API below as member functions of llama_vocab
llama_token llama_byte_to_token_impl(const llama_vocab & vocab, uint8_t ch);
const char * llama_token_get_text_impl(const struct llama_vocab & vocab, llama_token token);
#include "llama-impl.h"
#include "llama-vocab.h"
-#include "llama-grammar.h"
#include "llama-sampling.h"
#include "unicode.h"
struct llama_context {
llama_context(const llama_model & model)
: model(model)
- , sampling(llama_n_vocab(&model))
, t_start_us(model.t_start_us)
, t_load_us(model.t_load_us) {}
const struct llama_model & model;
struct llama_cparams cparams;
- struct llama_sampling sampling;
struct llama_sbatch sbatch;
struct llama_kv_cache kv_self;
struct llama_control_vector cvec;
bool has_evaluated_once = false;
- int64_t t_start_us;
- int64_t t_load_us;
- int64_t t_p_eval_us = 0;
- int64_t t_eval_us = 0;
+ mutable int64_t t_start_us;
+ mutable int64_t t_load_us;
+ mutable int64_t t_p_eval_us = 0;
+ mutable int64_t t_eval_us = 0;
- int64_t t_compute_start_us = 0;
- int64_t n_queued_tokens = 0;
+ mutable int64_t t_compute_start_us = 0;
+ mutable int64_t n_queued_tokens = 0;
- int32_t n_p_eval = 0; // number of tokens in eval calls for the prompt (with batch size > 1)
- int32_t n_eval = 0; // number of eval calls
+ mutable int32_t n_p_eval = 0; // number of tokens in eval calls for the prompt (with batch size > 1)
+ mutable int32_t n_eval = 0; // number of eval calls
// host buffer for the model output (logits and embeddings)
ggml_backend_buffer_t buf_output = nullptr;
const uint32_t n_vocab = gguf_get_arr_n(ctx, token_idx);
+ vocab.n_vocab = n_vocab;
vocab.id_to_token.resize(n_vocab);
for (uint32_t i = 0; i < n_vocab; i++) {
struct llama_context_params llama_context_default_params() {
struct llama_context_params result = {
- /*.seed =*/ LLAMA_DEFAULT_SEED,
/*.n_ctx =*/ 512,
/*.n_batch =*/ 2048,
/*.n_ubatch =*/ 512,
return result;
}
+struct llama_sampler_chain_params llama_sampler_chain_default_params() {
+ struct llama_sampler_chain_params result = {
+ /*.no_perf =*/ true,
+ };
+
+ return result;
+}
+
struct llama_model_quantize_params llama_model_quantize_default_params() {
struct llama_model_quantize_params result = {
/*.nthread =*/ 0,
cparams.causal_attn = params.attention_type == LLAMA_ATTENTION_TYPE_CAUSAL;
}
- if (params.seed == LLAMA_DEFAULT_SEED) {
- params.seed = time(NULL);
- }
-
LLAMA_LOG_INFO("%s: n_ctx = %u\n", __func__, cparams.n_ctx);
LLAMA_LOG_INFO("%s: n_batch = %u\n", __func__, cparams.n_batch);
LLAMA_LOG_INFO("%s: n_ubatch = %u\n", __func__, cparams.n_ubatch);
ctx->abort_callback = params.abort_callback;
ctx->abort_callback_data = params.abort_callback_data;
- ctx->sampling.rng = std::mt19937(params.seed);
- ctx->logits_all = params.logits_all;
+ ctx->logits_all = params.logits_all;
+
// build worst-case graph for encoder if a model contains encoder
- ctx->is_encoding = llama_model_has_encoder(model);
+ ctx->is_encoding = llama_model_has_encoder(model);
uint32_t kv_size = cparams.n_ctx;
ggml_type type_k = params.type_k;
delete ctx;
}
-const struct llama_model * llama_get_model(const struct llama_context * ctx) {
- return &ctx->model;
-}
-
-const struct llama_vocab * llama_get_vocab(const struct llama_context * ctx) {
- return &ctx->model.vocab;
-}
-
uint32_t llama_n_ctx(const struct llama_context * ctx) {
return ctx->cparams.n_ctx;
}
return model->vocab.type;
}
+int32_t llama_n_vocab(const struct llama_model * model) {
+ return model->hparams.n_vocab;
+}
+
+int32_t llama_n_ctx_train(const struct llama_model * model) {
+ return model->hparams.n_ctx_train;
+}
+
+int32_t llama_n_embd(const struct llama_model * model) {
+ return model->hparams.n_embd;
+}
+
+int32_t llama_n_layer(const struct llama_model * model) {
+ return model->hparams.n_layer;
+}
+
+const struct llama_model * llama_get_model(const struct llama_context * ctx) {
+ return &ctx->model;
+}
+
+enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx) {
+ return ctx->cparams.pooling_type;
+}
+
enum llama_rope_type llama_rope_type(const struct llama_model * model) {
switch (model->arch) {
// these models do not use RoPE
return LLAMA_ROPE_TYPE_NONE;
}
-enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx) {
- return ctx->cparams.pooling_type;
-}
-
-int32_t llama_n_vocab(const struct llama_model * model) {
- return model->hparams.n_vocab;
-}
-
-int32_t llama_n_ctx_train(const struct llama_model * model) {
- return model->hparams.n_ctx_train;
-}
-
-int32_t llama_n_embd(const struct llama_model * model) {
- return model->hparams.n_embd;
-}
-
-int32_t llama_n_layer(const struct llama_model * model) {
- return model->hparams.n_layer;
-}
-
float llama_rope_freq_scale_train(const struct llama_model * model) {
return model->hparams.rope_freq_scale_train;
}
// TODO: add more model-specific info which should prevent loading the session file if not identical
}
- void write_rng(const std::mt19937 & rng) {
- std::ostringstream rng_ss;
- rng_ss << rng;
+ //void write_rng(const std::mt19937 & rng) {
+ // std::ostringstream rng_ss;
+ // rng_ss << rng;
- const std::string & rng_str = rng_ss.str();
+ // const std::string & rng_str = rng_ss.str();
- write_string(rng_str);
- }
+ // write_string(rng_str);
+ //}
void write_output_ids(struct llama_context * ctx) {
llama_output_reorder(ctx);
// TODO: add more info which needs to be identical but which is not verified otherwise
}
- void read_rng(std::mt19937 & rng) {
- std::string rng_str;
- read_string(rng_str);
+ //void read_rng(std::mt19937 & rng) {
+ // std::string rng_str;
+ // read_string(rng_str);
- std::istringstream rng_ss(rng_str);
- rng_ss >> rng;
+ // std::istringstream rng_ss(rng_str);
+ // rng_ss >> rng;
- if (rng_ss.fail()) {
- throw std::runtime_error("failed to load RNG state");
- }
- }
+ // if (rng_ss.fail()) {
+ // throw std::runtime_error("failed to load RNG state");
+ // }
+ //}
void read_output_ids(struct llama_context * ctx) {
std::vector<int32_t> output_pos;
data_ctx.write_model_info(ctx);
- data_ctx.write_rng(ctx->sampling.rng);
-
// copy outputs
data_ctx.write_output_ids(ctx);
data_ctx.write_logits(ctx);
data_ctx.read_model_info(ctx);
- // set rng
- data_ctx.read_rng(ctx->sampling.rng);
-
// set outputs
data_ctx.read_output_ids(ctx);
data_ctx.read_logits(ctx);
LLAMA_LOG_ERROR("%s: invalid logits id %d, reason: %s\n", __func__, i, err.what());
#ifndef NDEBUG
GGML_ABORT("fatal error");
-#endif
+#else
return nullptr;
+#endif
}
}
LLAMA_LOG_ERROR("%s: invalid embeddings id %d, reason: %s\n", __func__, i, err.what());
#ifndef NDEBUG
GGML_ABORT("fatal error");
-#endif
+#else
return nullptr;
+#endif
}
}
}
//
-// grammar
+// sampling
//
-struct llama_grammar * llama_grammar_init(
- const llama_grammar_element ** rules,
- size_t n_rules,
- size_t start_rule_index) {
- return llama_grammar_init_impl(rules, n_rules, start_rule_index);
-}
-
-void llama_grammar_free(struct llama_grammar * grammar) {
- llama_grammar_free_impl(grammar);
-}
-
-struct llama_grammar * llama_grammar_copy(const struct llama_grammar * grammar) {
- return llama_grammar_copy_impl(grammar);
-}
-
-void llama_grammar_sample(
- const struct llama_grammar * grammar,
- const struct llama_context * ctx,
- llama_token_data_array * candidates) {
- llama_grammar_sample_impl(grammar, &ctx->model.vocab, &ctx->sampling, candidates);
-}
-
-void llama_sample_grammar(
- struct llama_context * ctx,
- llama_token_data_array * candidates,
- const struct llama_grammar * grammar) {
- llama_grammar_sample(grammar, ctx, candidates);
-}
-
-void llama_grammar_accept_token(
- struct llama_grammar * grammar,
- struct llama_context * ctx,
- llama_token token) {
- llama_grammar_accept_token_impl(grammar, &ctx->model.vocab, &ctx->sampling, token);
+// TODO: remove indirection when vocab becomes accessible in llama-sampling.cpp
+struct llama_sampler * llama_sampler_init_grammar(const struct llama_model * model, const char * grammar_str, const char * grammar_root) {
+ return llama_sampler_init_grammar_impl(model->vocab, grammar_str, grammar_root);
}
//
-// sampling
+// model split
//
-void llama_set_rng_seed(struct llama_context * ctx, uint32_t seed) {
- llama_set_rng_seed_impl(&ctx->sampling, seed);
-}
-
-void llama_sample_softmax(struct llama_context * ctx, llama_token_data_array * candidates) {
- llama_sample_softmax_impl(ctx ? &ctx->sampling : nullptr, candidates);
-}
-
-void llama_sample_top_k(struct llama_context * ctx, llama_token_data_array * candidates, int32_t k, size_t min_keep) {
- llama_sample_top_k_impl(ctx ? &ctx->sampling : nullptr, candidates, k, min_keep);
-}
-
-void llama_sample_top_p(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep) {
- llama_sample_top_p_impl(ctx ? &ctx->sampling : nullptr, candidates, p, min_keep);
-}
-
-void llama_sample_min_p(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep) {
- llama_sample_min_p_impl(ctx ? &ctx->sampling : nullptr, candidates, p, min_keep);
-}
-
-void llama_sample_tail_free(struct llama_context * ctx, llama_token_data_array * candidates, float z, size_t min_keep) {
- llama_sample_tail_free_impl(ctx ? &ctx->sampling : nullptr, candidates, z, min_keep);
-}
-
-void llama_sample_typical(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep) {
- llama_sample_typical_impl(ctx ? &ctx->sampling : nullptr, candidates, p, min_keep);
-}
-
-void llama_sample_entropy(struct llama_context * ctx, llama_token_data_array * candidates_p, float min_temp, float max_temp, float exponent_val) {
- llama_sample_entropy_impl(ctx ? &ctx->sampling : nullptr, candidates_p, min_temp, max_temp, exponent_val);
-}
-
-void llama_sample_temp(struct llama_context * ctx, llama_token_data_array * candidates_p, float temp) {
- llama_sample_temp_impl(ctx ? &ctx->sampling : nullptr, candidates_p, temp);
-}
-
-void llama_sample_repetition_penalties(
- struct llama_context * ctx,
- llama_token_data_array * candidates,
- const llama_token * last_tokens,
- size_t penalty_last_n,
- float penalty_repeat,
- float penalty_freq,
- float penalty_present) {
- llama_sample_repetition_penalties_impl(ctx ? &ctx->sampling : nullptr, candidates, last_tokens, penalty_last_n, penalty_repeat, penalty_freq, penalty_present);
-}
-
-void llama_sample_apply_guidance(
- struct llama_context * ctx,
- float * logits,
- float * logits_guidance,
- float scale) {
- llama_sample_apply_guidance_impl(&ctx->sampling, logits, logits_guidance, scale);
-}
-
-llama_token llama_sample_token_mirostat(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, int32_t m, float * mu) {
- return llama_sample_token_mirostat_impl(&ctx->sampling, candidates, tau, eta, m, mu);
-}
-
-llama_token llama_sample_token_mirostat_v2(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, float * mu) {
- return llama_sample_token_mirostat_v2_impl(ctx ? &ctx->sampling : nullptr, candidates, tau, eta, mu);
-}
-
-llama_token llama_sample_token_greedy(struct llama_context * ctx, llama_token_data_array * candidates) {
- return llama_sample_token_greedy_impl(ctx ? &ctx->sampling : nullptr, candidates);
-}
-
-llama_token llama_sample_token_with_rng(struct llama_context * ctx, llama_token_data_array * candidates, std::mt19937 & rng) {
- return llama_sample_token_with_rng_impl(&ctx->sampling, candidates, rng);
-}
-
-llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_array * candidates) {
- return llama_sample_token_with_rng_impl(&ctx->sampling, candidates, ctx->sampling.rng);
-}
-
int llama_split_path(char * split_path, size_t maxlen, const char * path_prefix, int split_no, int split_count) {
static const char * const SPLIT_PATH_FORMAT = "%s-%05d-of-%05d.gguf";
if (snprintf(split_path, maxlen, SPLIT_PATH_FORMAT, path_prefix, split_no + 1, split_count)) {
return 0;
}
-struct llama_timings llama_get_timings(struct llama_context * ctx) {
- struct llama_timings result = {
- /*.t_start_ms =*/ 1e-3 * ctx->t_start_us,
- /*.t_end_ms =*/ 1.00 * ggml_time_ms(),
- /*.t_load_ms =*/ 1e-3 * ctx->t_load_us,
- /*.t_sample_ms =*/ 1e-3 * ctx->sampling.t_sample_us,
- /*.t_p_eval_ms =*/ 1e-3 * ctx->t_p_eval_us,
- /*.t_eval_ms =*/ 1e-3 * ctx->t_eval_us,
-
- /*.n_sample =*/ std::max(1, ctx->sampling.n_sample),
- /*.n_p_eval =*/ std::max(0, ctx->n_p_eval),
- /*.n_eval =*/ std::max(1, ctx->n_eval),
- };
-
- return result;
-}
-
-void llama_print_timings(struct llama_context * ctx) {
- const llama_timings timings = llama_get_timings(ctx);
-
- LLAMA_LOG_INFO("\n");
- LLAMA_LOG_INFO("%s: load time = %10.2f ms\n", __func__, timings.t_load_ms);
- LLAMA_LOG_INFO("%s: sample time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n",
- __func__, timings.t_sample_ms, timings.n_sample, timings.t_sample_ms / timings.n_sample, 1e3 / timings.t_sample_ms * timings.n_sample);
- LLAMA_LOG_INFO("%s: prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n",
- __func__, timings.t_p_eval_ms, timings.n_p_eval, timings.t_p_eval_ms / timings.n_p_eval, 1e3 / timings.t_p_eval_ms * timings.n_p_eval);
- LLAMA_LOG_INFO("%s: eval time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n",
- __func__, timings.t_eval_ms, timings.n_eval, timings.t_eval_ms / timings.n_eval, 1e3 / timings.t_eval_ms * timings.n_eval);
- LLAMA_LOG_INFO("%s: total time = %10.2f ms / %5d tokens\n", __func__, (timings.t_end_ms - timings.t_start_ms), (timings.n_p_eval + timings.n_eval));
-}
-
-void llama_reset_timings(struct llama_context * ctx) {
- ctx->t_start_us = ggml_time_us();
- ctx->t_eval_us = ctx->n_eval = 0;
- ctx->t_p_eval_us = ctx->n_p_eval = 0;
-
- ctx->sampling.reset_timings();
-}
-
const char * llama_print_system_info(void) {
static std::string s;
return s.c_str();
}
-void llama_dump_timing_info_yaml(FILE * stream, const llama_context * ctx) {
+void llama_perf_print(const void * ctx, enum llama_perf_type type) {
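+ // ctx points to either a llama_context or a llama_sampler chain, depending on the perf type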
+ switch (type) {
+ case LLAMA_PERF_TYPE_CONTEXT:
+ {
+ const auto * p = (const struct llama_context *) ctx;
+
+ const double t_start_ms = 1e-3 * p->t_start_us;
+ const double t_end_ms = 1.00 * ggml_time_ms();
+ const double t_load_ms = 1e-3 * p->t_load_us;
+ const double t_p_eval_ms = 1e-3 * p->t_p_eval_us;
+ const double t_eval_ms = 1e-3 * p->t_eval_us;
+
+ const int32_t n_p_eval = std::max(0, p->n_p_eval);
+ const int32_t n_eval = std::max(1, p->n_eval);
+
+ LLAMA_LOG_INFO("%s: load time = %10.2f ms\n", __func__, t_load_ms);
+ LLAMA_LOG_INFO("%s: prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n",
+ __func__, t_p_eval_ms, n_p_eval, t_p_eval_ms / n_p_eval, 1e3 / t_p_eval_ms * n_p_eval);
+ LLAMA_LOG_INFO("%s: eval time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n",
+ __func__, t_eval_ms, n_eval, t_eval_ms / n_eval, 1e3 / t_eval_ms * n_eval);
+ LLAMA_LOG_INFO("%s: total time = %10.2f ms / %5d tokens\n", __func__, (t_end_ms - t_start_ms), (n_p_eval + n_eval));
+ } break;
+ case LLAMA_PERF_TYPE_SAMPLER_CHAIN:
+ {
+ const auto * smpl = (const struct llama_sampler *) ctx;
+ const auto * p = (const struct llama_sampler_chain *) smpl->ctx;
+
+ const double t_sampler_ms = 1e-3 * p->t_sample_us;
+
+ const int32_t n_sampler = std::max(0, p->n_sample);
+
+ LLAMA_LOG_INFO("%s: sampling time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n",
+ __func__, t_sampler_ms, n_sampler, t_sampler_ms / n_sampler, 1e3 / t_sampler_ms * n_sampler);
+ } break;
+ default:
+ GGML_ABORT("invalid perf type");
+ }
+}
+
+void llama_perf_reset(void * ctx, enum llama_perf_type type) {
+ switch (type) {
+ case LLAMA_PERF_TYPE_CONTEXT:
+ {
+ auto * p = (struct llama_context *) ctx;
+
+ p->t_start_us = ggml_time_us();
+ p->t_eval_us = p->n_eval = 0;
+ p->t_p_eval_us = p->n_p_eval = 0;
+ } break;
+ case LLAMA_PERF_TYPE_SAMPLER_CHAIN:
+ {
+ auto * smpl = (struct llama_sampler *) ctx;
+ auto * p = (struct llama_sampler_chain *) smpl->ctx;
+
+ p->t_sample_us = p->n_sample = 0;
+ } break;
+ default:
+ GGML_ABORT("invalid perf type");
+ }
+}
+
+void llama_perf_dump_yaml(FILE * stream, const llama_context * ctx) {
fprintf(stream, "\n");
fprintf(stream, "###########\n");
fprintf(stream, "# Timings #\n");
1.0e-3 * ctx->t_eval_us / ctx->n_eval);
fprintf(stream, "mst_p_eval: %.2f # ms / token during prompt processing\n",
1.0e-3 * ctx->t_p_eval_us / ctx->n_p_eval);
- fprintf(stream, "mst_sample: %.2f # ms / token during sampling\n",
- 1.0e-3 * ctx->sampling.t_sample_us / ctx->sampling.n_sample);
fprintf(stream, "n_eval: %d # number of tokens generated (excluding the first one)\n", ctx->n_eval);
fprintf(stream, "n_p_eval: %d # number of tokens processed in batches at the beginning\n", ctx->n_p_eval);
- fprintf(stream, "n_sample: %d # number of sampled tokens\n", ctx->sampling.n_sample);
fprintf(stream, "t_eval_us: %" PRId64 " # total microseconds spent generating tokens\n", ctx->t_eval_us);
fprintf(stream, "t_load_us: %" PRId64 " # total microseconds spent loading the model\n", ctx->t_load_us);
fprintf(stream, "t_p_eval_us: %" PRId64 " # total microseconds spent prompt processing\n", ctx->t_p_eval_us);
- fprintf(stream, "t_sample_us: %" PRId64 " # total microseconds spent sampling\n", ctx->sampling.t_sample_us);
fprintf(stream, "ts_eval: %.2f # tokens / second during generation\n",
1.0e6 * ctx->n_eval / ctx->t_eval_us);
fprintf(stream, "ts_p_eval: %.2f # tokens / second during prompt processing\n",
1.0e6 * ctx->n_p_eval / ctx->t_p_eval_us);
- fprintf(stream, "ts_sample: %.2f # tokens / second during sampling\n",
- 1.0e6 * ctx->sampling.n_sample / ctx->sampling.t_sample_us);
}
// For internal test use
#undef NDEBUG
#endif
-#define LLAMA_API_INTERNAL
-
-#include "ggml.h"
-#include "llama.h"
-#include "grammar-parser.h"
-#include "json-schema-to-grammar.h"
#include "unicode.h"
+#include "llama-grammar.h"
+#include "json-schema-to-grammar.h"
+
#include <cassert>
#include <string>
#include <vector>
using json = nlohmann::ordered_json;
-static llama_grammar* build_grammar(const std::string & grammar_str) {
- auto parsed_grammar = grammar_parser::parse(grammar_str.c_str());
-
- // Ensure we parsed correctly
- assert(!parsed_grammar.rules.empty());
-
- // Ensure we have a root node
- assert(!(parsed_grammar.symbol_ids.find("root") == parsed_grammar.symbol_ids.end()));
-
- std::vector<const llama_grammar_element*> grammar_rules(parsed_grammar.c_rules());
- llama_grammar* grammar = llama_grammar_init(
- grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
-
- return grammar;
+static llama_grammar * build_grammar(const std::string & grammar_str) {
+ return llama_grammar_init_impl(nullptr, grammar_str.c_str(), "root");
}
static bool test_build_grammar_fails(const std::string & grammar_str) {
}
static bool match_string(const std::string & input, llama_grammar * grammar) {
- auto decoded = decode_utf8(input, {});
-
- const auto & code_points = decoded.first;
+ const auto cpts = unicode_cpts_from_utf8(input);
const llama_grammar_rules & rules = llama_grammar_get_rules (grammar);
- llama_grammar_stacks & cur_stacks = llama_grammar_get_stacks(grammar);
+ llama_grammar_stacks & stacks_cur = llama_grammar_get_stacks(grammar);
- for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
- const llama_grammar_stacks prev_stacks = llama_grammar_get_stacks(grammar); // copy
+ for (const auto & cpt : cpts) {
+ const llama_grammar_stacks stacks_prev = llama_grammar_get_stacks(grammar); // copy
- llama_grammar_accept(rules, prev_stacks, *it, cur_stacks);
+ llama_grammar_accept(rules, stacks_prev, cpt, stacks_cur);
- if (cur_stacks.empty()) {
+ if (stacks_cur.empty()) {
// no stacks means that the grammar failed to match at this point
return false;
}
}
- for (const auto & stack : cur_stacks) {
+ for (const auto & stack : stacks_cur) {
if (stack.empty()) {
// An empty stack means that the grammar has been completed
return true;
fprintf(stderr, "⚫ Testing %s\n%s\n", test_desc.c_str(), grammar_str.c_str());
fflush(stderr);
- auto grammar = build_grammar(grammar_str);
+ auto * grammar = build_grammar(grammar_str);
// Save the original grammar stacks so that we can reset after every new string we want to test
- const llama_grammar_stacks original_stacks = llama_grammar_get_stacks(grammar);
+ const llama_grammar_stacks stacks_org = llama_grammar_get_stacks(grammar);
- llama_grammar_stacks & cur_stacks = llama_grammar_get_stacks(grammar);
+ llama_grammar_stacks & stacks_cur = llama_grammar_get_stacks(grammar);
fprintf(stderr, " 🔵 Valid strings:\n");
assert(matched);
// Reset the grammar stacks
- cur_stacks = original_stacks;
+ stacks_cur = stacks_org;
}
fprintf(stderr, " 🟠 Invalid strings:\n");
assert(!matched);
// Reset the grammar stacks
- cur_stacks = original_stacks;
+ stacks_cur = stacks_org;
}
// Clean up allocated memory
- llama_grammar_free(grammar);
+ llama_grammar_free_impl(grammar);
}
static void test_grammar(const std::string & test_desc, const std::string & grammar_str, const std::vector<std::string> & passing_strings, const std::vector<std::string> & failing_strings) {
test(test_desc + ". Grammar: " + grammar_str, grammar_str, passing_strings, failing_strings);
term ::= number
number ::= [0-9]+)""";
- grammar_parser::parse_state parsed_grammar = grammar_parser::parse(grammar_str.c_str());
+ llama_grammar_parser parsed_grammar;
+ parsed_grammar.parse(grammar_str.c_str());
// Ensure we parsed correctly
assert(!parsed_grammar.rules.empty());
fprintf(stderr, " Expected error: ");
- grammar_parser::parse_state parsed_grammar = grammar_parser::parse(grammar_str.c_str());
+ llama_grammar_parser parsed_grammar;
+ parsed_grammar.parse(grammar_str.c_str());
// Ensure we did NOT parse correctly
assert(parsed_grammar.rules.empty());
#endif
#include "llama.h"
-#include "grammar-parser.h"
+#include "llama-grammar.h"
#include <cassert>
static void verify_parsing(const char *grammar_bytes, const std::vector<std::pair<std::string, uint32_t>> expected, const std::vector<llama_grammar_element> &expected_rules) {
uint32_t index = 0;
- grammar_parser::parse_state parsed_grammar = grammar_parser::parse(grammar_bytes);
+ llama_grammar_parser parsed_grammar;
+ parsed_grammar.parse(grammar_bytes);
std::map<uint32_t, std::string> symbol_names;
for (auto it = parsed_grammar.symbol_ids.begin(); it != parsed_grammar.symbol_ids.end(); ++it) {
}
}
-static void verify_failure(const char *grammar_bytes) {
+static void verify_failure(const char * grammar_bytes) {
fprintf(stderr, "Testing expected failure:%s\n", grammar_bytes);
- auto result = grammar_parser::parse(grammar_bytes);
+ llama_grammar_parser result;
+ result.parse(grammar_bytes);
assert(result.rules.empty() && "should have failed");
}
#undef NDEBUG
#endif
+#include "json-schema-to-grammar.h"
+
+#include "llama-grammar.h"
+
#include <cassert>
#include <fstream>
#include <sstream>
#include <regex>
-#include "json-schema-to-grammar.h"
-#include "grammar-parser.h"
-
static std::string trim(const std::string & source) {
std::string s(source);
s.erase(0,s.find_first_not_of(" \n\r\t"));
}
void verify_expectation_parseable() const {
try {
- auto state = grammar_parser::parse(expected_grammar.c_str());
+ llama_grammar_parser state;
+ state.parse(expected_grammar.c_str());
if (state.symbol_ids.find("root") == state.symbol_ids.end()) {
throw std::runtime_error("Grammar failed to parse:\n" + expected_grammar);
}
#undef NDEBUG
#endif
-#define LLAMA_API_INTERNAL
#include "llama.h"
-#include "grammar-parser.h"
+#include "llama-grammar.h"
#include <cassert>
#include <stdexcept>
int main()
{
- grammar_parser::parse_state parsed_grammar;
+ llama_grammar_parser parsed_grammar;
std::vector<std::pair<std::string, uint32_t>> expected = {
{"expr", 2},
llama_grammar * grammar = NULL;
std::vector<const llama_grammar_element *> grammar_rules(parsed_grammar.c_rules());
- grammar = llama_grammar_init(grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
+ grammar = llama_grammar_init_impl(nullptr, grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
if (grammar == nullptr)
{
throw std::runtime_error("Failed to initialize llama_grammar");
}};
auto index = 0;
- for (auto stack : llama_grammar_get_stacks(grammar))
+ for (const llama_grammar_stack & stack : llama_grammar_get_stacks(grammar))
{
// compare stack to expected_stack
for (uint32_t i = 0; i < stack.size(); i++)
{
- auto element = stack[i];
- auto expected_element = expected_stacks[index][i];
+ const llama_grammar_element * element = stack[i];
+ const llama_grammar_element & expected_element = expected_stacks[index][i];
// pretty print error message before asserting
if (expected_element.type != element->type || expected_element.value != element->value)
delete[] candidate.code_points;
candidate.code_points = nullptr;
}
- llama_grammar_free(grammar);
+
+ llama_grammar_free_impl(grammar);
+
return 0;
}
#include "ggml.h"
#include "llama.h"
+#include "llama-sampling.h"
#ifdef NDEBUG
#undef NDEBUG
#include <string>
#include <vector>
-static void dump(const llama_token_data_array * candidates) {
- for (size_t i = 0; i < candidates->size; i++) {
- printf("%d: %f (%f)\n", candidates->data[i].id, candidates->data[i].p, candidates->data[i].logit);
+static void dump(const llama_token_data_array * cur_p) {
+ for (size_t i = 0; i < cur_p->size; i++) {
+ printf("%d: %f (%f)\n", cur_p->data[i].id, cur_p->data[i].p, cur_p->data[i].logit);
}
}
-#define DUMP(__candidates) do { printf("%s:%d (%s)\n", __FILE__, __LINE__, __func__); dump((__candidates)); printf("-\n"); } while(0)
+#define DUMP(__cur_p) do { printf("%s:%d (%s)\n", __FILE__, __LINE__, __func__); dump((__cur_p)); printf("-\n"); } while(0)
+
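+// APPLY: one-shot sampler helper - build the constraint, apply it to the candidate array in place, then free it,
+// e.g. APPLY(llama_sampler_init_top_k(k), &cur_p)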
+#define APPLY(__cnstr, __cur_p) do { \
+ auto * cnstr = (__cnstr); \
+ llama_sampler_apply(cnstr, (__cur_p)); \
+ llama_sampler_free(cnstr); \
+} while(0)
static void test_top_k(const std::vector<float> & probs, const std::vector<float> & expected_probs, int k) {
const size_t n_vocab = probs.size();
- std::vector<llama_token_data> candidates;
- candidates.reserve(n_vocab);
+
+ std::vector<llama_token_data> cur;
+ cur.reserve(n_vocab);
for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
const float logit = logf(probs[token_id]);
- candidates.emplace_back(llama_token_data{token_id, logit, 0.0f});
+ cur.emplace_back(llama_token_data{token_id, logit, 0.0f});
}
- llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
- llama_sample_softmax(nullptr, &candidates_p);
- DUMP(&candidates_p);
- llama_sample_top_k(nullptr, &candidates_p, k, 1);
- DUMP(&candidates_p);
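+ // candidate array layout: { data, size, selected, sorted }; -1 means no token has been selected yet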
+ llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false };
+ APPLY(llama_sampler_init_softmax(), &cur_p);
+ DUMP(&cur_p);
+ APPLY(llama_sampler_init_top_k(k), &cur_p);
+ DUMP(&cur_p);
- GGML_ASSERT(candidates_p.size == expected_probs.size());
- for (size_t i = 0; i < candidates_p.size; i++) {
- GGML_ASSERT(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-5);
+ GGML_ASSERT(cur_p.size == expected_probs.size());
+ for (size_t i = 0; i < cur_p.size; i++) {
+ GGML_ASSERT(fabs(cur_p.data[i].p - expected_probs[i]) < 1e-5);
}
}
static void test_top_p(const std::vector<float> & probs, const std::vector<float> & expected_probs, float p) {
const size_t n_vocab = probs.size();
- std::vector<llama_token_data> candidates;
- candidates.reserve(n_vocab);
+
+ std::vector<llama_token_data> cur;
+ cur.reserve(n_vocab);
for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
const float logit = logf(probs[token_id]);
- candidates.emplace_back(llama_token_data{token_id, logit, 0.0f});
+ cur.emplace_back(llama_token_data{token_id, logit, 0.0f});
}
- llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
- llama_sample_softmax(nullptr, &candidates_p);
- DUMP(&candidates_p);
- llama_sample_top_p(nullptr, &candidates_p, p, 1);
- DUMP(&candidates_p);
+ llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false };
+ APPLY(llama_sampler_init_softmax(), &cur_p);
+ DUMP(&cur_p);
+ APPLY(llama_sampler_init_top_p(p, 1), &cur_p);
+ DUMP(&cur_p);
- GGML_ASSERT(candidates_p.size == expected_probs.size());
- for (size_t i = 0; i < candidates_p.size; i++) {
- GGML_ASSERT(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-3);
+ GGML_ASSERT(cur_p.size == expected_probs.size());
+ for (size_t i = 0; i < cur_p.size; i++) {
+ GGML_ASSERT(fabs(cur_p.data[i].p - expected_probs[i]) < 1e-3);
}
}
static void test_tfs(const std::vector<float> & probs, const std::vector<float> & expected_probs, float z) {
const size_t n_vocab = probs.size();
- std::vector<llama_token_data> candidates;
- candidates.reserve(n_vocab);
+
+ std::vector<llama_token_data> cur;
+ cur.reserve(n_vocab);
for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
const float logit = logf(probs[token_id]);
- candidates.emplace_back(llama_token_data{token_id, logit, 0.0f});
+ cur.emplace_back(llama_token_data{token_id, logit, 0.0f});
}
- llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
- DUMP(&candidates_p);
- llama_sample_tail_free(nullptr, &candidates_p, z, 1);
- DUMP(&candidates_p);
+ llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false };
+ DUMP(&cur_p);
+ APPLY(llama_sampler_init_tail_free(z, 1), &cur_p);
+ DUMP(&cur_p);
- GGML_ASSERT(candidates_p.size == expected_probs.size());
- for (size_t i = 0; i < candidates_p.size; i++) {
- GGML_ASSERT(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-3);
+ GGML_ASSERT(cur_p.size == expected_probs.size());
+ for (size_t i = 0; i < cur_p.size; i++) {
+ GGML_ASSERT(fabs(cur_p.data[i].p - expected_probs[i]) < 1e-3);
}
}
static void test_min_p(const std::vector<float> & probs, const std::vector<float> & expected_probs, float p) {
const size_t n_vocab = probs.size();
- std::vector<llama_token_data> candidates;
- candidates.reserve(n_vocab);
+
+ std::vector<llama_token_data> cur;
+ cur.reserve(n_vocab);
for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
const float logit = logf(probs[token_id]);
- candidates.emplace_back(llama_token_data{token_id, logit, 0.0f});
+ cur.emplace_back(llama_token_data{token_id, logit, 0.0f});
}
- llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
- DUMP(&candidates_p);
- llama_sample_min_p(nullptr, &candidates_p, p, 1);
- DUMP(&candidates_p);
- llama_sample_softmax(nullptr, &candidates_p);
+ llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false };
+ DUMP(&cur_p);
+ APPLY(llama_sampler_init_min_p(p, 1), &cur_p);
+ DUMP(&cur_p);
+ APPLY(llama_sampler_init_softmax(), &cur_p);
- GGML_ASSERT(candidates_p.size == expected_probs.size());
- for (size_t i = 0; i < candidates_p.size; i++) {
- GGML_ASSERT(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-3);
+ GGML_ASSERT(cur_p.size == expected_probs.size());
+ for (size_t i = 0; i < cur_p.size; i++) {
+ GGML_ASSERT(fabs(cur_p.data[i].p - expected_probs[i]) < 1e-3);
}
}
static void test_typical(const std::vector<float> & probs, const std::vector<float> & expected_probs, float p) {
const size_t n_vocab = probs.size();
- std::vector<llama_token_data> candidates;
- candidates.reserve(n_vocab);
+
+ std::vector<llama_token_data> cur;
+ cur.reserve(n_vocab);
for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
const float logit = logf(probs[token_id]);
- candidates.emplace_back(llama_token_data{token_id, logit, 0.0f});
+ cur.emplace_back(llama_token_data{token_id, logit, 0.0f});
}
- llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
- DUMP(&candidates_p);
- llama_sample_typical(nullptr, &candidates_p, p, 1);
- DUMP(&candidates_p);
+ llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false };
+ DUMP(&cur_p);
+ APPLY(llama_sampler_init_typical(p, 1), &cur_p);
+ DUMP(&cur_p);
- GGML_ASSERT(candidates_p.size == expected_probs.size());
- for (size_t i = 0; i < candidates_p.size; i++) {
- GGML_ASSERT(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-3);
+ GGML_ASSERT(cur_p.size == expected_probs.size());
+ for (size_t i = 0; i < cur_p.size; i++) {
+ GGML_ASSERT(fabs(cur_p.data[i].p - expected_probs[i]) < 1e-3);
}
}
-static void test_repetition_penalties(
+static void test_penalties(
const std::vector<float> & probs, const std::vector<llama_token> & last_tokens,
const std::vector<float> & expected_probs, float repeat_penalty, float alpha_frequency, float alpha_presence
) {
GGML_ASSERT(probs.size() == expected_probs.size());
const size_t n_vocab = probs.size();
- std::vector<llama_token_data> candidates;
- candidates.reserve(n_vocab);
+
+ std::vector<llama_token_data> cur;
+ cur.reserve(n_vocab);
for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
const float logit = logf(probs[token_id]);
- candidates.emplace_back(llama_token_data{token_id, logit, 0.0f});
+ cur.emplace_back(llama_token_data{token_id, logit, 0.0f});
+ }
+
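+ // tally occurrences of each recent token for the repeat/frequency/presence penalties below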
+ llama_token_cnt token_count;
+ for (size_t i = 0; i < last_tokens.size(); i++) {
+ token_count[last_tokens[i]]++;
}
- llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
- llama_sample_softmax(nullptr, &candidates_p);
- DUMP(&candidates_p);
- llama_sample_repetition_penalties(nullptr, &candidates_p, (const llama_token *) last_tokens.data(), last_tokens.size(), repeat_penalty, alpha_frequency, alpha_presence);
- llama_sample_softmax(nullptr, &candidates_p);
- DUMP(&candidates_p);
+ llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false };
+ APPLY(llama_sampler_init_softmax(), &cur_p);
+ DUMP(&cur_p);
+ llama_sampler_penalties_impl(&cur_p, token_count, repeat_penalty, alpha_frequency, alpha_presence); // TODO: avoid
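+ // the penalties go through the internal helper here rather than a public sampler, presumably because the public one tracks its own token history instead of taking explicit counts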
+ APPLY(llama_sampler_init_softmax(), &cur_p);
+ DUMP(&cur_p);
- GGML_ASSERT(candidates_p.size == expected_probs.size());
- for (size_t i = 0; i < candidates_p.size; i++) {
- GGML_ASSERT(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-3);
+ GGML_ASSERT(cur_p.size == expected_probs.size());
+ for (size_t i = 0; i < cur_p.size; i++) {
+ GGML_ASSERT(fabs(cur_p.data[i].p - expected_probs[i]) < 1e-3);
}
}
-static void test_sampler_queue(
- const size_t n_vocab, const std::string samplers_sequence, const int top_k, const float top_p, const float min_p
+static void test_sampler_queue(const size_t n_vocab, const std::string & samplers_sequence, const int top_k, const float top_p, const float min_p
) {
- std::vector<llama_token_data> candidates;
- candidates.reserve(n_vocab);
+ std::vector<llama_token_data> cur;
+ cur.reserve(n_vocab);
for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
const float logit = logf(token_id);
- candidates.emplace_back(llama_token_data{token_id, logit, 0.0f});
+ cur.emplace_back(llama_token_data{token_id, logit, 0.0f});
}
- llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
+ llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false };
llama_token min_token_id = 0;
const llama_token max_token_id = n_vocab-1;
for (auto s : samplers_sequence) {
switch (s){
- case 'k': llama_sample_top_k (nullptr, &candidates_p, top_k, 1); break;
+ case 'k': APPLY(llama_sampler_init_top_k(top_k), &cur_p); break;
case 'f': GGML_ABORT("tail_free test not implemented");
case 'y': GGML_ABORT("typical test not implemented");
- case 'p': llama_sample_top_p (nullptr, &candidates_p, top_p, 1); break;
- case 'm': llama_sample_min_p (nullptr, &candidates_p, min_p, 1); break;
+ case 'p': APPLY(llama_sampler_init_top_p(top_p, 1), &cur_p); break;
+ case 'm': APPLY(llama_sampler_init_min_p(min_p, 1), &cur_p); break;
case 't': GGML_ABORT("temperature test not implemented");
default : GGML_ABORT("Unknown sampler");
}
- llama_sample_softmax(nullptr, &candidates_p); // make sure tokens are sorted for tests
+ APPLY(llama_sampler_init_softmax(), &cur_p); // make sure tokens are sorted for tests
- const int size = candidates_p.size;
+ const int size = cur_p.size;
if (s == 'k') {
const int expected_size = std::min(size, top_k);
min_token_id = std::max(min_token_id, (llama_token)(n_vocab - top_k));
GGML_ASSERT(size == expected_size);
- GGML_ASSERT(candidates_p.data[0].id == max_token_id);
- GGML_ASSERT(candidates_p.data[expected_size-1].id == min_token_id);
+ GGML_ASSERT(cur_p.data[0].id == max_token_id);
+ GGML_ASSERT(cur_p.data[expected_size-1].id == min_token_id);
} else if (s == 'p') {
const int softmax_divisor = n_vocab * (n_vocab-1) / 2 - min_token_id * (min_token_id-1) / 2;
const int softmax_numerator_target = ceilf(top_p * softmax_divisor);
}
GGML_ASSERT(size == expected_size);
- GGML_ASSERT(candidates_p.data[0].id == max_token_id);
- GGML_ASSERT(candidates_p.data[expected_size-1].id == min_token_id);
+ GGML_ASSERT(cur_p.data[0].id == max_token_id);
+ GGML_ASSERT(cur_p.data[expected_size-1].id == min_token_id);
} else if (s == 'm') {
int expected_size = ceilf((1.0f-min_p) * n_vocab);
expected_size = std::max(expected_size, 1);
min_token_id = std::min(min_token_id, (llama_token)(n_vocab - 1));
GGML_ASSERT(size == expected_size);
- GGML_ASSERT(candidates_p.data[0].id == max_token_id);
- GGML_ASSERT(candidates_p.data[expected_size-1].id == min_token_id);
+ GGML_ASSERT(cur_p.data[0].id == max_token_id);
+ GGML_ASSERT(cur_p.data[expected_size-1].id == min_token_id);
} else {
GGML_ABORT("fatal error");
}
test_typical({0.97f, 0.01f, 0.01f, 0.01f}, {0.97f}, 0.5f);
test_typical({0.4f, 0.2f, 0.2f, 0.2f}, {0.2f, 0.2f, 0.2f}, 0.5f);
- test_repetition_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0}, {0.25f, 0.25f, 0.25f, 0.25f, 0}, 50.0f, 0.0f, 0.0f);
- test_repetition_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2}, {0.5f, 0.5f, 0, 0, 0}, 50.0f, 0.0f, 0.0f);
- test_repetition_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 0}, {0.5f, 0.5f, 0, 0, 0}, 50.0f, 0.0f, 0.0f);
+ test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0}, {0.25f, 0.25f, 0.25f, 0.25f, 0}, 50.0f, 0.0f, 0.0f);
+ test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2}, {0.5f, 0.5f, 0, 0, 0}, 50.0f, 0.0f, 0.0f);
+ test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 0}, {0.5f, 0.5f, 0, 0, 0}, 50.0f, 0.0f, 0.0f);
- test_repetition_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0}, {0.249997f, 0.249997f, 0.249997f, 0.249997f, 0.000011f}, 1.0f, 5.0f, 5.0f);
- test_repetition_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2}, {0.499966f, 0.499966f, 0.000023f, 0.000023f, 0.000023f}, 1.0f, 5.0f, 5.0f);
- test_repetition_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 0}, {0.499977f, 0.499977f, 0.000023f, 0.000023f, 0.000000f}, 1.0f, 5.0f, 5.0f);
+ test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0}, {0.249997f, 0.249997f, 0.249997f, 0.249997f, 0.000011f}, 1.0f, 5.0f, 5.0f);
+ test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2}, {0.499966f, 0.499966f, 0.000023f, 0.000023f, 0.000023f}, 1.0f, 5.0f, 5.0f);
+ test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 0}, {0.499977f, 0.499977f, 0.000023f, 0.000023f, 0.000000f}, 1.0f, 5.0f, 5.0f);
test_sampler_queue(10000, "k", 10000, 1.0f, 1.0f);
test_sampler_queue(10000, "k", 1, 1.0f, 1.0f);