]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
llama : refactor sampling v2 (#9294)
authorGeorgi Gerganov <redacted>
Sat, 7 Sep 2024 12:16:19 +0000 (15:16 +0300)
committerGitHub <redacted>
Sat, 7 Sep 2024 12:16:19 +0000 (15:16 +0300)
- Add `struct llama_sampler` and `struct llama_sampler_i`
- Add `llama_sampler_` API
- Add `llama_sampler_chain_` API for chaining multiple samplers
- Remove `LLAMA_API_INTERNAL`
- Add `llama_perf_` API and remove old `llama_print_timings` and `llama_reset_timings`

48 files changed:
Makefile
common/CMakeLists.txt
common/common.cpp
common/common.h
common/grammar-parser.cpp [deleted file]
common/grammar-parser.h [deleted file]
common/sampling.cpp
common/sampling.h
examples/batched-bench/batched-bench.cpp
examples/batched.swift/Sources/main.swift
examples/batched/batched.cpp
examples/embedding/embedding.cpp
examples/eval-callback/eval-callback.cpp
examples/gbnf-validator/gbnf-validator.cpp
examples/gritlm/gritlm.cpp
examples/imatrix/imatrix.cpp
examples/infill/infill.cpp
examples/llama-bench/llama-bench.cpp
examples/llama.android/llama/src/main/cpp/llama-android.cpp
examples/llama.swiftui/llama.cpp.swift/LibLlama.swift
examples/llava/llava-cli.cpp
examples/llava/minicpmv-cli.cpp
examples/lookahead/lookahead.cpp
examples/lookup/lookup.cpp
examples/main/main.cpp
examples/parallel/parallel.cpp
examples/passkey/passkey.cpp
examples/perplexity/perplexity.cpp
examples/quantize-stats/quantize-stats.cpp
examples/retrieval/retrieval.cpp
examples/save-load-state/save-load-state.cpp
examples/server/README.md
examples/server/server.cpp
examples/simple/simple.cpp
examples/speculative/speculative.cpp
include/llama.h
src/llama-grammar.cpp
src/llama-grammar.h
src/llama-impl.h
src/llama-sampling.cpp
src/llama-sampling.h
src/llama-vocab.h
src/llama.cpp
tests/test-grammar-integration.cpp
tests/test-grammar-parser.cpp
tests/test-json-schema-to-grammar.cpp
tests/test-llama-grammar.cpp
tests/test-sampling.cpp

index 332496cfc39c1a6fac42b2bbc70d80194092ad36..89287831ff31f9e06f0575118440218a128d94d9 100644 (file)
--- a/Makefile
+++ b/Makefile
@@ -927,7 +927,6 @@ OBJ_COMMON = \
        common/ngram-cache.o \
        common/sampling.o \
        common/train.o \
-       common/grammar-parser.o \
        common/build-info.o \
        common/json-schema-to-grammar.o
 
@@ -1167,11 +1166,6 @@ common/console.o: \
        common/console.h
        $(CXX) $(CXXFLAGS) -c $< -o $@
 
-common/grammar-parser.o: \
-       common/grammar-parser.cpp \
-       common/grammar-parser.h
-       $(CXX) $(CXXFLAGS) -c $< -o $@
-
 common/json-schema-to-grammar.o: \
        common/json-schema-to-grammar.cpp \
        common/json-schema-to-grammar.h
index 761971d6881f38e29a963314c65fc19ed227acb9..2c72793b89dbe2f65cd34724cc7ceb608b28c510 100644 (file)
@@ -58,8 +58,6 @@ add_library(${TARGET} STATIC
     sampling.cpp
     console.h
     console.cpp
-    grammar-parser.h
-    grammar-parser.cpp
     json.hpp
     json-schema-to-grammar.cpp
     train.h
index de2a177c165b4e73bcb68475e0c08097af07cb69..6394301318c4bb8710447d12268cd046579469c9 100644 (file)
@@ -353,16 +353,15 @@ void postprocess_cpu_params(cpu_params& cpuparams, const cpu_params* role_model)
 }
 
 bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
-    bool invalid_param = false;
-    std::string arg;
-    const std::string arg_prefix = "--";
-    llama_sampling_params & sparams = params.sparams;
-
     for (int i = 1; i < argc; i++) {
-        arg = argv[i];
+        const std::string arg_prefix = "--";
+
+        std::string arg = argv[i];
         if (arg.compare(0, arg_prefix.size(), arg_prefix) == 0) {
             std::replace(arg.begin(), arg.end(), '_', '-');
         }
+
+        bool invalid_param = false;
         if (!gpt_params_find_arg(argc, argv, arg, params, i, invalid_param)) {
             throw std::invalid_argument("error: unknown argument: " + arg);
         }
@@ -386,11 +385,12 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
         get_env("HF_TOKEN", params.hf_token);
     }
 
+    auto & sparams = params.sparams;
+
     if (params.escape) {
         string_process_escapes(params.prompt);
         string_process_escapes(params.input_prefix);
         string_process_escapes(params.input_suffix);
-        string_process_escapes(sparams.cfg_negative_prompt);
         for (auto & antiprompt : params.antiprompt) {
             string_process_escapes(antiprompt);
         }
@@ -401,6 +401,10 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
         params.kv_overrides.back().key[0] = 0;
     }
 
+    if (sparams.seed == LLAMA_DEFAULT_SEED) {
+        sparams.seed = time(NULL);
+    }
+
     return true;
 }
 
@@ -526,12 +530,10 @@ bool parse_cpu_mask(const std::string & mask, bool (&boolmask)[GGML_MAX_N_THREAD
 bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_params & params, int & i, bool & invalid_param) {
     const char split_delim = ',';
 
-    llama_sampling_params & sparams = params.sparams;
+    auto & sparams = params.sparams;
 
     if (arg == "-s" || arg == "--seed") {
         CHECK_ARG
-        // TODO: this is temporary, in the future the sampling state will be moved fully to llama_sampling_context.
-        params.seed = std::stoul(argv[i]);
         sparams.seed = std::stoul(argv[i]);
         return true;
     }
@@ -842,12 +844,12 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
     if (arg == "--samplers") {
         CHECK_ARG
         const auto sampler_names = string_split(argv[i], ';');
-        sparams.samplers_sequence = llama_sampling_types_from_names(sampler_names, true);
+        sparams.samplers = gpt_sampler_types_from_names(sampler_names, true);
         return true;
     }
     if (arg == "--sampling-seq") {
         CHECK_ARG
-        sparams.samplers_sequence = llama_sampling_types_from_chars(argv[i]);
+        sparams.samplers = gpt_sampler_types_from_chars(argv[i]);
         return true;
     }
     if (arg == "--top-p") {
@@ -873,7 +875,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
     }
     if (arg == "--typical") {
         CHECK_ARG
-        sparams.typical_p = std::stof(argv[i]);
+        sparams.typ_p = std::stof(argv[i]);
         return true;
     }
     if (arg == "--repeat-last-n") {
@@ -922,30 +924,6 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
         sparams.mirostat_tau = std::stof(argv[i]);
         return true;
     }
-    if (arg == "--cfg-negative-prompt") {
-        CHECK_ARG
-        sparams.cfg_negative_prompt = argv[i];
-        return true;
-    }
-    if (arg == "--cfg-negative-prompt-file") {
-        CHECK_ARG
-        std::ifstream file(argv[i]);
-        if (!file) {
-            fprintf(stderr, "error: failed to open file '%s'\n", argv[i]);
-            invalid_param = true;
-            return true;
-        }
-        std::copy(std::istreambuf_iterator<char>(file), std::istreambuf_iterator<char>(), back_inserter(sparams.cfg_negative_prompt));
-        if (!sparams.cfg_negative_prompt.empty() && sparams.cfg_negative_prompt.back() == '\n') {
-            sparams.cfg_negative_prompt.pop_back();
-        }
-        return true;
-    }
-    if (arg == "--cfg-scale") {
-        CHECK_ARG
-        sparams.cfg_scale = std::stof(argv[i]);
-        return true;
-    }
     if (arg == "-b" || arg == "--batch-size") {
         CHECK_ARG
         params.n_batch = std::stoi(argv[i]);
@@ -1355,7 +1333,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
         return true;
     }
     if (arg == "--ignore-eos") {
-        params.ignore_eos = true;
+        sparams.ignore_eos = true;
         return true;
     }
     if (arg == "--penalize-nl") {
@@ -1370,7 +1348,8 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
         std::string value_str;
         try {
             if (ss >> key && ss >> sign && std::getline(ss, value_str) && (sign == '+' || sign == '-')) {
-                sparams.logit_bias[key] = std::stof(value_str) * ((sign == '-') ? -1.0f : 1.0f);
+                const float bias = std::stof(value_str) * ((sign == '-') ? -1.0f : 1.0f);
+                sparams.logit_bias.push_back({key, bias});
             }
             else {
                 throw std::exception();
@@ -1725,13 +1704,13 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
 #endif
 
 void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
-    const llama_sampling_params & sparams = params.sparams;
+    const auto & sparams = params.sparams;
 
     std::string sampler_type_chars;
     std::string sampler_type_names;
-    for (const auto sampler_type : sparams.samplers_sequence) {
-        sampler_type_chars += static_cast<char>(sampler_type);
-        sampler_type_names += llama_sampling_type_to_str(sampler_type) + ";";
+    for (const auto & sampler : sparams.samplers) {
+        sampler_type_chars += gpt_sampler_type_to_chr(sampler);
+        sampler_type_names += gpt_sampler_type_to_str(sampler) + ";";
     }
     sampler_type_names.pop_back();
 
@@ -1766,7 +1745,6 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
     options.push_back({ "*",           "       --verbose-prompt",       "print a verbose prompt before generation (default: %s)", params.verbose_prompt ? "true" : "false" });
     options.push_back({ "*",           "       --no-display-prompt",    "don't print prompt at generation (default: %s)", !params.display_prompt ? "true" : "false" });
     options.push_back({ "*",           "-co,   --color",                "colorise output to distinguish prompt and user input from generations (default: %s)", params.use_color ? "true" : "false" });
-    options.push_back({ "*",           "-s,    --seed SEED",            "RNG seed (default: %d, use random seed for < 0)", params.seed });
     options.push_back({ "*",           "-t,    --threads N",            "number of threads to use during generation (default: %d)", params.cpuparams.n_threads });
     options.push_back({ "*",           "-tb,   --threads-batch N",      "number of threads to use during batch and prompt processing (default: same as --threads)" });
     options.push_back({ "speculative", "-td,   --threads-draft N",      "number of threads to use during generation (default: same as --threads)" });
@@ -1846,18 +1824,19 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
                                        "       --spm-infill",           "use Suffix/Prefix/Middle pattern for infill (instead of Prefix/Suffix/Middle) as some models prefer this. (default: %s)", params.spm_infill ? "enabled" : "disabled" });
 
     options.push_back({ "sampling" });
+    options.push_back({ "*",           "-s,    --seed SEED",            "RNG seed (default: %d, use random seed for < 0)", sparams.seed });
     options.push_back({ "*",           "       --samplers SAMPLERS",    "samplers that will be used for generation in the order, separated by \';\'\n"
                                                                         "(default: %s)", sampler_type_names.c_str() });
     options.push_back({ "*",           "       --sampling-seq SEQUENCE",
                                                                         "simplified sequence for samplers that will be used (default: %s)", sampler_type_chars.c_str() });
     options.push_back({ "*",           "       --ignore-eos",           "ignore end of stream token and continue generating (implies --logit-bias EOS-inf)" });
     options.push_back({ "*",           "       --penalize-nl",          "penalize newline tokens (default: %s)", sparams.penalize_nl ? "true" : "false" });
-    options.push_back({ "*",           "       --temp N",               "temperature (default: %.1f)", (double)sparams.temp });
+    options.push_back({ "*",           "       --temp T",               "temperature (default: %.1f)", (double)sparams.temp });
     options.push_back({ "*",           "       --top-k N",              "top-k sampling (default: %d, 0 = disabled)", sparams.top_k });
-    options.push_back({ "*",           "       --top-p N",              "top-p sampling (default: %.1f, 1.0 = disabled)", (double)sparams.top_p });
-    options.push_back({ "*",           "       --min-p N",              "min-p sampling (default: %.1f, 0.0 = disabled)", (double)sparams.min_p });
-    options.push_back({ "*",           "       --tfs N",                "tail free sampling, parameter z (default: %.1f, 1.0 = disabled)", (double)sparams.tfs_z });
-    options.push_back({ "*",           "       --typical N",            "locally typical sampling, parameter p (default: %.1f, 1.0 = disabled)", (double)sparams.typical_p });
+    options.push_back({ "*",           "       --top-p P",              "top-p sampling (default: %.1f, 1.0 = disabled)", (double)sparams.top_p });
+    options.push_back({ "*",           "       --min-p P",              "min-p sampling (default: %.1f, 0.0 = disabled)", (double)sparams.min_p });
+    options.push_back({ "*",           "       --tfs P",                "tail free sampling, parameter z (default: %.1f, 1.0 = disabled)", (double)sparams.tfs_z });
+    options.push_back({ "*",           "       --typical P",            "locally typical sampling, parameter p (default: %.1f, 1.0 = disabled)", (double)sparams.typ_p });
     options.push_back({ "*",           "       --repeat-last-n N",      "last n tokens to consider for penalize (default: %d, 0 = disabled, -1 = ctx_size)", sparams.penalty_last_n });
     options.push_back({ "*",           "       --repeat-penalty N",     "penalize repeat sequence of tokens (default: %.1f, 1.0 = disabled)", (double)sparams.penalty_repeat });
     options.push_back({ "*",           "       --presence-penalty N",   "repeat alpha presence penalty (default: %.1f, 0.0 = disabled)", (double)sparams.penalty_present });
@@ -1872,11 +1851,6 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
     options.push_back({ "*",           "       -l TOKEN_ID(+/-)BIAS",   "modifies the likelihood of token appearing in the completion,\n"
                                                                         "i.e. `--logit-bias 15043+1` to increase likelihood of token ' Hello',\n"
                                                                         "or `--logit-bias 15043-1` to decrease likelihood of token ' Hello'" });
-    options.push_back({ "main",        "       --cfg-negative-prompt PROMPT",
-                                                                        "negative prompt to use for guidance (default: '%s')", sparams.cfg_negative_prompt.c_str() });
-    options.push_back({ "main",        "       --cfg-negative-prompt-file FNAME",
-                                                                        "negative prompt file to use for guidance" });
-    options.push_back({ "main",        "       --cfg-scale N",          "strength of guidance (default: %.1f, 1.0 = disable)", (double)sparams.cfg_scale });
     options.push_back({ "main",        "       --chat-template JINJA_TEMPLATE",
                                                                         "set custom jinja chat template (default: template taken from model's metadata)\n"
                                                                         "if suffix/prefix are specified, template will be disabled\n"
@@ -2528,8 +2502,9 @@ struct llama_init_result llama_init_from_gpt_params(gpt_params & params) {
         llama_lora_adapters_apply(lctx, iparams.lora_adapters);
     }
 
-    if (params.ignore_eos) {
-        params.sparams.logit_bias[llama_token_eos(model)] = -INFINITY;
+    if (params.sparams.ignore_eos && llama_token_eos(model) == -1) {
+        fprintf(stderr, "%s: warning: model does not have an EOS token, ignoring --ignore-eos\n", __func__);
+        params.sparams.ignore_eos = false;
     }
 
     if (params.warmup) {
@@ -2558,7 +2533,7 @@ struct llama_init_result llama_init_from_gpt_params(gpt_params & params) {
         }
         llama_kv_cache_clear(lctx);
         llama_synchronize(lctx);
-        llama_reset_timings(lctx);
+        llama_perf_reset(lctx, LLAMA_PERF_TYPE_CONTEXT);
     }
 
     iparams.model   = model;
@@ -2637,7 +2612,6 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
     cparams.n_threads         = params.cpuparams.n_threads;
     cparams.n_threads_batch   = params.cpuparams_batch.n_threads == -1 ?
                                     params.cpuparams.n_threads : params.cpuparams_batch.n_threads;
-    cparams.seed              = params.seed;
     cparams.logits_all        = params.logits_all;
     cparams.embeddings        = params.embedding;
     cparams.rope_scaling_type = params.rope_scaling_type;
@@ -3523,7 +3497,7 @@ void yaml_dump_string_multiline(FILE * stream, const char * prop_name, const cha
 
 void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const llama_context * lctx,
                                const std::string & timestamp, const std::vector<int> & prompt_tokens, const char * model_desc) {
-    const llama_sampling_params & sparams = params.sparams;
+    const auto & sparams = params.sparams;
 
     fprintf(stream, "build_commit: %s\n",        LLAMA_COMMIT);
     fprintf(stream, "build_number: %d\n",        LLAMA_BUILD_NUMBER);
@@ -3574,8 +3548,6 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const l
 
     fprintf(stream, "alias: %s # default: unknown\n", params.model_alias.c_str());
     fprintf(stream, "batch_size: %d # default: 512\n", params.n_batch);
-    yaml_dump_string_multiline(stream, "cfg_negative_prompt", sparams.cfg_negative_prompt.c_str());
-    fprintf(stream, "cfg_scale: %f # default: 1.0\n", sparams.cfg_scale);
     fprintf(stream, "chunks: %d # default: -1 (unlimited)\n", params.n_chunks);
     fprintf(stream, "color: %s # default: false\n", params.use_color ? "true" : "false");
     fprintf(stream, "ctx_size: %d # default: 512\n", params.n_ctx);
@@ -3586,10 +3558,7 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const l
     fprintf(stream, "grammar-file: # never logged, see grammar instead. Can still be specified for input.\n");
     fprintf(stream, "hellaswag: %s # default: false\n", params.hellaswag ? "true" : "false");
     fprintf(stream, "hellaswag_tasks: %zu # default: 400\n", params.hellaswag_tasks);
-
-    const auto logit_bias_eos = sparams.logit_bias.find(llama_token_eos(llama_get_model(lctx)));
-    const bool ignore_eos = logit_bias_eos != sparams.logit_bias.end() && logit_bias_eos->second == -INFINITY;
-    fprintf(stream, "ignore_eos: %s # default: false\n", ignore_eos ? "true" : "false");
+    fprintf(stream, "ignore_eos: %s # default: false\n", sparams.ignore_eos ? "true" : "false");
 
     yaml_dump_string_multiline(stream, "in_prefix", params.input_prefix.c_str());
     fprintf(stream, "in_prefix_bos: %s # default: false\n", params.input_prefix_bos ? "true" : "false");
@@ -3600,11 +3569,8 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const l
     fprintf(stream, "logdir: %s # default: unset (no logging)\n", params.logdir.c_str());
 
     fprintf(stream, "logit_bias:\n");
-    for (std::pair<llama_token, float> lb : sparams.logit_bias) {
-        if (ignore_eos && lb.first == logit_bias_eos->first) {
-            continue;
-        }
-        fprintf(stream, "  %d: %f", lb.first, lb.second);
+    for (const auto & logit_bias : sparams.logit_bias) {
+        fprintf(stream, "  %d: %f", logit_bias.token, logit_bias.bias);
     }
 
     fprintf(stream, "lora:\n");
@@ -3657,7 +3623,6 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const l
 
     fprintf(stream, "rope_freq_base: %f # default: 10000.0\n", params.rope_freq_base);
     fprintf(stream, "rope_freq_scale: %f # default: 1.0\n", params.rope_freq_scale);
-    fprintf(stream, "seed: %u # default: -1 (random seed)\n", params.seed);
     fprintf(stream, "simple_io: %s # default: false\n", params.simple_io ? "true" : "false");
     fprintf(stream, "cont_batching: %s # default: false\n", params.cont_batching ? "true" : "false");
     fprintf(stream, "flash_attn: %s # default: false\n", params.flash_attn ? "true" : "false");
@@ -3671,7 +3636,7 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const l
     fprintf(stream, "top_k: %d # default: 40\n", sparams.top_k);
     fprintf(stream, "top_p: %f # default: 0.95\n", sparams.top_p);
     fprintf(stream, "min_p: %f # default: 0.0\n", sparams.min_p);
-    fprintf(stream, "typical_p: %f # default: 1.0\n", sparams.typical_p);
+    fprintf(stream, "typ_p: %f # default: 1.0\n", sparams.typ_p);
     fprintf(stream, "verbose_prompt: %s # default: false\n", params.verbose_prompt ? "true" : "false");
     fprintf(stream, "display_prompt: %s # default: true\n", params.display_prompt ? "true" : "false");
 }
index 795ff44054d403435606a6455863c44da263fd95..3a6c8e0b5377ab18ca51ee3545378b4d97af6cb0 100644 (file)
@@ -77,8 +77,6 @@ struct cpu_params {
 };
 
 struct gpt_params {
-    uint32_t seed                 = LLAMA_DEFAULT_SEED; // RNG seed
-
     int32_t n_predict             =    -1; // new tokens to predict
     int32_t n_ctx                 =     0; // context size
     int32_t n_batch               =  2048; // logical batch size for prompt processing (must be >=32 to use BLAS)
@@ -120,8 +118,7 @@ struct gpt_params {
     enum llama_pooling_type      pooling_type      = LLAMA_POOLING_TYPE_UNSPECIFIED; // pooling type for embeddings
     enum llama_attention_type    attention_type    = LLAMA_ATTENTION_TYPE_UNSPECIFIED; // attention type for embeddings
 
-    // // sampling parameters
-    struct llama_sampling_params sparams;
+    struct gpt_sampler_params sparams;
 
     std::string model                = ""; // model path
     std::string model_draft          = ""; // draft model for speculative decoding
@@ -185,7 +182,6 @@ struct gpt_params {
     bool flash_attn        = false; // flash attention
 
     bool input_prefix_bos  = false; // prefix BOS to user inputs, preceding input_prefix
-    bool ignore_eos        = false; // ignore generated EOS tokens
     bool logits_all        = false; // return logits for all tokens in the batch
     bool use_mmap          = true;  // use mmap for faster loads
     bool use_mlock         = false; // use mlock to keep model in memory
diff --git a/common/grammar-parser.cpp b/common/grammar-parser.cpp
deleted file mode 100644 (file)
index 438452e..0000000
+++ /dev/null
@@ -1,539 +0,0 @@
-#include "grammar-parser.h"
-#include <cstdint>
-#include <cwchar>
-#include <string>
-#include <utility>
-#include <stdexcept>
-#include <exception>
-
-namespace grammar_parser {
-    // NOTE: assumes valid utf8 (but checks for overrun)
-    // copied from llama.cpp
-    static std::pair<uint32_t, const char *> decode_utf8(const char * src) {
-        static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 };
-        uint8_t  first_byte = static_cast<uint8_t>(*src);
-        uint8_t  highbits   = first_byte >> 4;
-        int      len        = lookup[highbits];
-        uint8_t  mask       = (1 << (8 - len)) - 1;
-        uint32_t value      = first_byte & mask;
-        const char * end    = src + len; // may overrun!
-        const char * pos    = src + 1;
-        for ( ; pos < end && *pos; pos++) {
-            value = (value << 6) + (static_cast<uint8_t>(*pos) & 0x3F);
-        }
-        return std::make_pair(value, pos);
-    }
-
-    static uint32_t get_symbol_id(parse_state & state, const char * src, size_t len) {
-        uint32_t next_id = static_cast<uint32_t>(state.symbol_ids.size());
-        auto result = state.symbol_ids.emplace(std::string(src, len), next_id);
-        return result.first->second;
-    }
-
-    static uint32_t generate_symbol_id(parse_state & state, const std::string & base_name) {
-        uint32_t next_id = static_cast<uint32_t>(state.symbol_ids.size());
-        state.symbol_ids[base_name + '_' + std::to_string(next_id)] = next_id;
-        return next_id;
-    }
-
-    static void add_rule(
-            parse_state & state,
-            uint32_t      rule_id,
-            const std::vector<llama_grammar_element> & rule) {
-        if (state.rules.size() <= rule_id) {
-            state.rules.resize(rule_id + 1);
-        }
-        state.rules[rule_id] = rule;
-    }
-
-    static bool is_digit_char(char c) {
-        return '0' <= c && c <= '9';
-    }
-
-    static bool is_word_char(char c) {
-        return ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || c == '-' || is_digit_char(c);
-    }
-
-    static std::pair<uint32_t, const char *> parse_hex(const char * src, int size) {
-        const char * pos   = src;
-        const char * end   = src + size;
-        uint32_t     value = 0;
-        for ( ; pos < end && *pos; pos++) {
-            value <<= 4;
-            char c = *pos;
-            if ('a' <= c && c <= 'f') {
-                value += c - 'a' + 10;
-            } else if ('A' <= c && c <= 'F') {
-                value += c - 'A' + 10;
-            } else if ('0' <= c && c <= '9') {
-                value += c - '0';
-            } else {
-                break;
-            }
-        }
-        if (pos != end) {
-            throw std::runtime_error("expecting " + std::to_string(size) + " hex chars at " + src);
-        }
-        return std::make_pair(value, pos);
-    }
-
-    static const char * parse_space(const char * src, bool newline_ok) {
-        const char * pos = src;
-        while (*pos == ' ' || *pos == '\t' || *pos == '#' ||
-                (newline_ok && (*pos == '\r' || *pos == '\n'))) {
-            if (*pos == '#') {
-                while (*pos && *pos != '\r' && *pos != '\n') {
-                    pos++;
-                }
-            } else {
-                pos++;
-            }
-        }
-        return pos;
-    }
-
-    static const char * parse_name(const char * src) {
-        const char * pos = src;
-        while (is_word_char(*pos)) {
-            pos++;
-        }
-        if (pos == src) {
-            throw std::runtime_error(std::string("expecting name at ") + src);
-        }
-        return pos;
-    }
-
-    static const char * parse_int(const char * src) {
-        const char * pos = src;
-        while (is_digit_char(*pos)) {
-            pos++;
-        }
-        if (pos == src) {
-            throw std::runtime_error(std::string("expecting integer at ") + src);
-        }
-        return pos;
-    }
-
-    static std::pair<uint32_t, const char *> parse_char(const char * src) {
-        if (*src == '\\') {
-            switch (src[1]) {
-                case 'x': return parse_hex(src + 2, 2);
-                case 'u': return parse_hex(src + 2, 4);
-                case 'U': return parse_hex(src + 2, 8);
-                case 't': return std::make_pair('\t', src + 2);
-                case 'r': return std::make_pair('\r', src + 2);
-                case 'n': return std::make_pair('\n', src + 2);
-                case '\\':
-                case '"':
-                case '[':
-                case ']':
-                    return std::make_pair(src[1], src + 2);
-                default:
-                    throw std::runtime_error(std::string("unknown escape at ") + src);
-            }
-        } else if (*src) {
-            return decode_utf8(src);
-        }
-        throw std::runtime_error("unexpected end of input");
-    }
-
-    const char * parse_alternates(
-            parse_state       & state,
-            const char        * src,
-            const std::string & rule_name,
-            uint32_t            rule_id,
-            bool                is_nested);
-
-    static const char * parse_sequence(
-            parse_state                        & state,
-            const char                         * src,
-            const std::string                  & rule_name,
-            std::vector<llama_grammar_element> & out_elements,
-            bool                                 is_nested) {
-        size_t last_sym_start = out_elements.size();
-        const char * pos = src;
-
-        auto handle_repetitions = [&](int min_times, int max_times) {
-
-            if (last_sym_start == out_elements.size()) {
-                throw std::runtime_error(std::string("expecting preceding item to */+/?/{ at ") + pos);
-            }
-
-            // apply transformation to previous symbol (last_sym_start to end) according to
-            // the following rewrite rules:
-            // S{m,n} --> S S S (m times) S'(n-m)
-            //            S'(x)   ::= S S'(x-1) |
-            //            (... n-m definitions of these S' rules ...)
-            //            S'(1)   ::= S |
-            // S{m,} -->  S S S (m times) S'
-            //            S'     ::= S S' |
-            // S*     --> S{0,}
-            //        --> S'     ::= S S' |
-            // S+     --> S{1,}
-            //        --> S S'
-            //            S'     ::= S S' |
-            // S?     --> S{0,1}
-            //        --> S'
-            //            S'     ::= S |
-
-            std::vector<llama_grammar_element> previous_elements(out_elements.begin() + last_sym_start, out_elements.end());
-            if (min_times == 0) {
-                out_elements.resize(last_sym_start);
-            } else {
-                // Repeat the previous elements (min_times - 1) times
-                for (int i = 1; i < min_times; i++) {
-                    out_elements.insert(out_elements.end(), previous_elements.begin(), previous_elements.end());
-                }
-            }
-
-            uint32_t last_rec_rule_id = 0;
-            auto n_opt = max_times < 0 ? 1 : max_times - min_times;
-
-            std::vector<llama_grammar_element> rec_rule(previous_elements);
-            for (int i = 0; i < n_opt; i++) {
-                rec_rule.resize(previous_elements.size());
-                uint32_t rec_rule_id = generate_symbol_id(state, rule_name);
-                if (i > 0 || max_times < 0) {
-                    rec_rule.push_back({LLAMA_GRETYPE_RULE_REF, max_times < 0 ? rec_rule_id : last_rec_rule_id});
-                }
-                rec_rule.push_back({LLAMA_GRETYPE_ALT, 0});
-                rec_rule.push_back({LLAMA_GRETYPE_END, 0});
-                add_rule(state, rec_rule_id, rec_rule);
-                last_rec_rule_id = rec_rule_id;
-            }
-            if (n_opt > 0) {
-                out_elements.push_back({LLAMA_GRETYPE_RULE_REF, last_rec_rule_id});
-            }
-        };
-
-        while (*pos) {
-            if (*pos == '"') { // literal string
-                pos++;
-                last_sym_start = out_elements.size();
-                while (*pos != '"') {
-                    if (!*pos) {
-                        throw std::runtime_error("unexpected end of input");
-                    }
-                    auto char_pair = parse_char(pos);
-                         pos       = char_pair.second;
-                    out_elements.push_back({LLAMA_GRETYPE_CHAR, char_pair.first});
-                }
-                pos = parse_space(pos + 1, is_nested);
-            } else if (*pos == '[') { // char range(s)
-                pos++;
-                enum llama_gretype start_type = LLAMA_GRETYPE_CHAR;
-                if (*pos == '^') {
-                    pos++;
-                    start_type = LLAMA_GRETYPE_CHAR_NOT;
-                }
-                last_sym_start = out_elements.size();
-                while (*pos != ']') {
-                    if (!*pos) {
-                        throw std::runtime_error("unexpected end of input");
-                    }
-                    auto char_pair = parse_char(pos);
-                         pos       = char_pair.second;
-                    enum llama_gretype type = last_sym_start < out_elements.size()
-                        ? LLAMA_GRETYPE_CHAR_ALT
-                        : start_type;
-
-                    out_elements.push_back({type, char_pair.first});
-                    if (pos[0] == '-' && pos[1] != ']') {
-                        if (!pos[1]) {
-                            throw std::runtime_error("unexpected end of input");
-                        }
-                        auto endchar_pair = parse_char(pos + 1);
-                             pos          = endchar_pair.second;
-                        out_elements.push_back({LLAMA_GRETYPE_CHAR_RNG_UPPER, endchar_pair.first});
-                    }
-                }
-                pos = parse_space(pos + 1, is_nested);
-            } else if (is_word_char(*pos)) { // rule reference
-                const char * name_end    = parse_name(pos);
-                uint32_t     ref_rule_id = get_symbol_id(state, pos, name_end - pos);
-                pos = parse_space(name_end, is_nested);
-                last_sym_start = out_elements.size();
-                out_elements.push_back({LLAMA_GRETYPE_RULE_REF, ref_rule_id});
-            } else if (*pos == '(') { // grouping
-                // parse nested alternates into synthesized rule
-                pos = parse_space(pos + 1, true);
-                uint32_t sub_rule_id = generate_symbol_id(state, rule_name);
-                pos = parse_alternates(state, pos, rule_name, sub_rule_id, true);
-                last_sym_start = out_elements.size();
-                // output reference to synthesized rule
-                out_elements.push_back({LLAMA_GRETYPE_RULE_REF, sub_rule_id});
-                if (*pos != ')') {
-                    throw std::runtime_error(std::string("expecting ')' at ") + pos);
-                }
-                pos = parse_space(pos + 1, is_nested);
-            } else if (*pos == '.') { // any char
-                last_sym_start = out_elements.size();
-                out_elements.push_back({LLAMA_GRETYPE_CHAR_ANY, 0});
-                pos = parse_space(pos + 1, is_nested);
-            } else if (*pos == '*') {
-                pos = parse_space(pos + 1, is_nested);
-                handle_repetitions(0, -1);
-            } else if (*pos == '+') {
-                pos = parse_space(pos + 1, is_nested);
-                handle_repetitions(1, -1);
-            } else if (*pos == '?') {
-                pos = parse_space(pos + 1, is_nested);
-                handle_repetitions(0, 1);
-            } else if (*pos == '{') {
-                pos = parse_space(pos + 1, is_nested);
-
-                if (!is_digit_char(*pos)) {
-                    throw std::runtime_error(std::string("expecting an int at ") + pos);
-                }
-                const char * int_end = parse_int(pos);
-                int min_times = std::stoul(std::string(pos, int_end - pos));
-                pos = parse_space(int_end, is_nested);
-
-                int max_times = -1;
-
-                if (*pos == '}') {
-                    max_times = min_times;
-                    pos = parse_space(pos + 1, is_nested);
-                } else if (*pos == ',') {
-                    pos = parse_space(pos + 1, is_nested);
-
-                    if (is_digit_char(*pos)) {
-                        const char * int_end = parse_int(pos);
-                        max_times = std::stoul(std::string(pos, int_end - pos));
-                        pos = parse_space(int_end, is_nested);
-                    }
-
-                    if (*pos != '}') {
-                        throw std::runtime_error(std::string("expecting '}' at ") + pos);
-                    }
-                    pos = parse_space(pos + 1, is_nested);
-                } else {
-                    throw std::runtime_error(std::string("expecting ',' at ") + pos);
-                }
-                handle_repetitions(min_times, max_times);
-            } else {
-                break;
-            }
-        }
-        return pos;
-    }
-
-    const char * parse_alternates(
-            parse_state       & state,
-            const char        * src,
-            const std::string & rule_name,
-            uint32_t            rule_id,
-            bool                is_nested) {
-        std::vector<llama_grammar_element> rule;
-        const char * pos = parse_sequence(state, src, rule_name, rule, is_nested);
-        while (*pos == '|') {
-            rule.push_back({LLAMA_GRETYPE_ALT, 0});
-            pos = parse_space(pos + 1, true);
-            pos = parse_sequence(state, pos, rule_name, rule, is_nested);
-        }
-        rule.push_back({LLAMA_GRETYPE_END, 0});
-        add_rule(state, rule_id, rule);
-        return pos;
-    }
-
-    static const char * parse_rule(parse_state & state, const char * src) {
-        const char * name_end = parse_name(src);
-        const char * pos      = parse_space(name_end, false);
-        size_t       name_len = name_end - src;
-        uint32_t     rule_id  = get_symbol_id(state, src, name_len);
-        const std::string name(src, name_len);
-
-        if (!(pos[0] == ':' && pos[1] == ':' && pos[2] == '=')) {
-            throw std::runtime_error(std::string("expecting ::= at ") + pos);
-        }
-        pos = parse_space(pos + 3, true);
-
-        pos = parse_alternates(state, pos, name, rule_id, false);
-
-        if (*pos == '\r') {
-            pos += pos[1] == '\n' ? 2 : 1;
-        } else if (*pos == '\n') {
-            pos++;
-        } else if (*pos) {
-            throw std::runtime_error(std::string("expecting newline or end at ") + pos);
-        }
-        return parse_space(pos, true);
-    }
-
-    parse_state parse(const char * src) {
-        try {
-            parse_state state;
-            const char * pos = parse_space(src, true);
-            while (*pos) {
-                pos = parse_rule(state, pos);
-            }
-            // Validate the state to ensure that all rules are defined
-            for (const auto & rule : state.rules) {
-                if (rule.empty()) {
-                    throw std::runtime_error("Undefined rule");
-                }
-                for (const auto & elem : rule) {
-                    if (elem.type == LLAMA_GRETYPE_RULE_REF) {
-                        // Ensure that the rule at that location exists
-                        if (elem.value >= state.rules.size() || state.rules[elem.value].empty()) {
-                            // Get the name of the rule that is missing
-                            for (const auto & kv : state.symbol_ids) {
-                                if (kv.second == elem.value) {
-                                    throw std::runtime_error("Undefined rule identifier '" + kv.first + "'");
-                                }
-                            }
-                        }
-                    }
-                }
-            }
-            return state;
-        } catch (const std::exception & err) {
-            fprintf(stderr, "%s: error parsing grammar: %s\n", __func__, err.what());
-            return parse_state();
-        }
-    }
-
-    static void print_grammar_char(FILE * file, uint32_t c) {
-        if (0x20 <= c && c <= 0x7f) {
-            fprintf(file, "%c", static_cast<char>(c));
-        } else {
-            // cop out of encoding UTF-8
-            fprintf(file, "<U+%04X>", c);
-        }
-    }
-
-    static bool is_char_element(llama_grammar_element elem) {
-        switch (elem.type) {
-            case LLAMA_GRETYPE_CHAR:           return true;
-            case LLAMA_GRETYPE_CHAR_NOT:       return true;
-            case LLAMA_GRETYPE_CHAR_ALT:       return true;
-            case LLAMA_GRETYPE_CHAR_RNG_UPPER: return true;
-            case LLAMA_GRETYPE_CHAR_ANY:       return true;
-            default:                           return false;
-        }
-    }
-
-    static void print_rule_binary(FILE * file, const std::vector<llama_grammar_element> & rule) {
-        for (auto elem : rule) {
-            switch (elem.type) {
-                case LLAMA_GRETYPE_END:            fprintf(file, "END");            break;
-                case LLAMA_GRETYPE_ALT:            fprintf(file, "ALT");            break;
-                case LLAMA_GRETYPE_RULE_REF:       fprintf(file, "RULE_REF");       break;
-                case LLAMA_GRETYPE_CHAR:           fprintf(file, "CHAR");           break;
-                case LLAMA_GRETYPE_CHAR_NOT:       fprintf(file, "CHAR_NOT");       break;
-                case LLAMA_GRETYPE_CHAR_RNG_UPPER: fprintf(file, "CHAR_RNG_UPPER"); break;
-                case LLAMA_GRETYPE_CHAR_ALT:       fprintf(file, "CHAR_ALT");       break;
-                case LLAMA_GRETYPE_CHAR_ANY:       fprintf(file, "CHAR_ANY");       break;
-            }
-            switch (elem.type) {
-                case LLAMA_GRETYPE_END:
-                case LLAMA_GRETYPE_ALT:
-                case LLAMA_GRETYPE_RULE_REF:
-                    fprintf(file, "(%u) ", elem.value);
-                    break;
-                case LLAMA_GRETYPE_CHAR:
-                case LLAMA_GRETYPE_CHAR_NOT:
-                case LLAMA_GRETYPE_CHAR_RNG_UPPER:
-                case LLAMA_GRETYPE_CHAR_ALT:
-                case LLAMA_GRETYPE_CHAR_ANY:
-                    fprintf(file, "(\"");
-                    print_grammar_char(file, elem.value);
-                    fprintf(file, "\") ");
-                    break;
-            }
-        }
-        fprintf(file, "\n");
-    }
-
-    static void print_rule(
-            FILE     * file,
-            uint32_t   rule_id,
-            const std::vector<llama_grammar_element> & rule,
-            const std::map<uint32_t, std::string>    & symbol_id_names) {
-        if (rule.empty() || rule.back().type != LLAMA_GRETYPE_END) {
-            throw std::runtime_error(
-                "malformed rule, does not end with LLAMA_GRETYPE_END: " + std::to_string(rule_id));
-        }
-        fprintf(file, "%s ::= ", symbol_id_names.at(rule_id).c_str());
-        for (size_t i = 0, end = rule.size() - 1; i < end; i++) {
-            llama_grammar_element elem = rule[i];
-            switch (elem.type) {
-                case LLAMA_GRETYPE_END:
-                    throw std::runtime_error(
-                        "unexpected end of rule: " + std::to_string(rule_id) + "," +
-                        std::to_string(i));
-                case LLAMA_GRETYPE_ALT:
-                    fprintf(file, "| ");
-                    break;
-                case LLAMA_GRETYPE_RULE_REF:
-                    fprintf(file, "%s ", symbol_id_names.at(elem.value).c_str());
-                    break;
-                case LLAMA_GRETYPE_CHAR:
-                    fprintf(file, "[");
-                    print_grammar_char(file, elem.value);
-                    break;
-                case LLAMA_GRETYPE_CHAR_NOT:
-                    fprintf(file, "[^");
-                    print_grammar_char(file, elem.value);
-                    break;
-                case LLAMA_GRETYPE_CHAR_RNG_UPPER:
-                    if (i == 0 || !is_char_element(rule[i - 1])) {
-                        throw std::runtime_error(
-                            "LLAMA_GRETYPE_CHAR_RNG_UPPER without preceding char: " +
-                            std::to_string(rule_id) + "," + std::to_string(i));
-                    }
-                    fprintf(file, "-");
-                    print_grammar_char(file, elem.value);
-                    break;
-                case LLAMA_GRETYPE_CHAR_ALT:
-                    if (i == 0 || !is_char_element(rule[i - 1])) {
-                        throw std::runtime_error(
-                            "LLAMA_GRETYPE_CHAR_ALT without preceding char: " +
-                            std::to_string(rule_id) + "," + std::to_string(i));
-                    }
-                    print_grammar_char(file, elem.value);
-                    break;
-                case LLAMA_GRETYPE_CHAR_ANY:
-                    fprintf(file, ".");
-                    break;
-            }
-            if (is_char_element(elem)) {
-                switch (rule[i + 1].type) {
-                    case LLAMA_GRETYPE_CHAR_ALT:
-                    case LLAMA_GRETYPE_CHAR_RNG_UPPER:
-                    case LLAMA_GRETYPE_CHAR_ANY:
-                        break;
-                    default:
-                        fprintf(file, "] ");
-                }
-            }
-        }
-        fprintf(file, "\n");
-    }
-
-    void print_grammar(FILE * file, const parse_state & state) {
-        try {
-            std::map<uint32_t, std::string> symbol_id_names;
-            for (const auto & kv : state.symbol_ids) {
-                symbol_id_names[kv.second] = kv.first;
-            }
-            for (size_t i = 0, end = state.rules.size(); i < end; i++) {
-                // fprintf(file, "%zu: ", i);
-                // print_rule_binary(file, state.rules[i]);
-                print_rule(file, uint32_t(i), state.rules[i], symbol_id_names);
-                // fprintf(file, "\n");
-            }
-        } catch (const std::exception & err) {
-            fprintf(stderr, "\n%s: error printing grammar: %s\n", __func__, err.what());
-        }
-    }
-
-    std::vector<const llama_grammar_element *> parse_state::c_rules() {
-        std::vector<const llama_grammar_element *> ret;
-        ret.reserve(rules.size());
-        for (const auto & rule : rules) {
-            ret.push_back(rule.data());
-        }
-        return ret;
-    }
-}
diff --git a/common/grammar-parser.h b/common/grammar-parser.h
deleted file mode 100644 (file)
index 9037d72..0000000
+++ /dev/null
@@ -1,29 +0,0 @@
-// Implements a parser for an extended Backus-Naur form (BNF), producing the
-// binary context-free grammar format specified by llama.h. Supports character
-// ranges, grouping, and repetition operators. As an example, a grammar for
-// arithmetic might look like:
-//
-// root  ::= expr
-// expr  ::= term ([-+*/] term)*
-// term  ::= num | "(" space expr ")" space
-// num   ::= [0-9]+ space
-// space ::= [ \t\n]*
-
-#pragma once
-#include "llama.h"
-#include <vector>
-#include <map>
-#include <cstdint>
-#include <string>
-
-namespace grammar_parser {
-    struct parse_state {
-        std::map<std::string, uint32_t>                 symbol_ids;
-        std::vector<std::vector<llama_grammar_element>> rules;
-
-        std::vector<const llama_grammar_element *> c_rules();
-    };
-
-    parse_state parse(const char * src);
-    void print_grammar(FILE * file, const parse_state & state);
-}
index 079e405168dff237f1c8c49abedb0c1b387b76bf..c81b4d233b04e87ddd16d30218ce544d25b265e5 100644 (file)
-#define LLAMA_API_INTERNAL
 #include "sampling.h"
-#include <random>
 
-struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params) {
-    struct llama_sampling_context * result = new llama_sampling_context();
+#include "common.h"
 
-    result->params  = params;
-    result->grammar = nullptr;
+// the ring buffer works similarly to std::deque, but with a fixed capacity
+// TODO: deduplicate with llama-impl.h
+template<typename T>
+struct ring_buffer {
+    ring_buffer(size_t cap) : capacity(cap), data(cap) {}
 
-    // if there is a grammar, parse it
-    if (!params.grammar.empty()) {
-        result->parsed_grammar = grammar_parser::parse(params.grammar.c_str());
-
-        // will be empty (default) if there are parse errors
-        if (result->parsed_grammar.rules.empty()) {
-            fprintf(stderr, "%s: failed to parse grammar\n", __func__);
-            delete result;
-            return nullptr;
+    T & front() {
+        if (sz == 0) {
+            throw std::runtime_error("ring buffer is empty");
         }
+        return data[first];
+    }
 
-        // Ensure that there is a "root" node.
-        if (result->parsed_grammar.symbol_ids.find("root") == result->parsed_grammar.symbol_ids.end()) {
-            fprintf(stderr, "%s: grammar does not contain a 'root' symbol\n", __func__);
-            delete result;
-            return nullptr;
+    const T & front() const {
+        if (sz == 0) {
+            throw std::runtime_error("ring buffer is empty");
         }
+        return data[first];
+    }
 
-        std::vector<const llama_grammar_element *> grammar_rules(result->parsed_grammar.c_rules());
-
-        struct llama_grammar * grammar = llama_grammar_init(
-                grammar_rules.data(),
-                grammar_rules.size(), result->parsed_grammar.symbol_ids.at("root"));
-        if (grammar == nullptr) {
-            throw std::runtime_error("Failed to initialize llama_grammar");
+    T & back() {
+        if (sz == 0) {
+            throw std::runtime_error("ring buffer is empty");
         }
-        result->grammar = grammar;
+        return data[pos];
     }
 
-    result->prev.resize(params.n_prev);
-
-    result->n_valid = 0;
-
-    llama_sampling_set_rng_seed(result, params.seed);
-
-    return result;
-}
-
-void llama_sampling_free(struct llama_sampling_context * ctx) {
-    if (ctx->grammar != NULL) {
-        llama_grammar_free(ctx->grammar);
+    const T & back() const {
+        if (sz == 0) {
+            throw std::runtime_error("ring buffer is empty");
+        }
+        return data[pos];
     }
 
-    delete ctx;
-}
-
-void llama_sampling_reset(llama_sampling_context * ctx) {
-    if (ctx->grammar != NULL) {
-        llama_grammar_free(ctx->grammar);
-        ctx->grammar = NULL;
+    void push_back(const T & value) {
+        if (sz == capacity) {
+            // advance the start when buffer is full
+            first = (first + 1) % capacity;
+        } else {
+            sz++;
+        }
+        data[pos] = value;
+        pos = (pos + 1) % capacity;
     }
 
-    if (!ctx->parsed_grammar.rules.empty()) {
-        std::vector<const llama_grammar_element *> grammar_rules(ctx->parsed_grammar.c_rules());
+    T pop_front() {
+        if (sz == 0) {
+            throw std::runtime_error("ring buffer is empty");
+        }
+        T value = data[first];
+        first = (first + 1) % capacity;
+        sz--;
+        return value;
+    }
 
-        struct llama_grammar * grammar = llama_grammar_init(
-                grammar_rules.data(),
-                grammar_rules.size(), ctx->parsed_grammar.symbol_ids.at("root"));
-        if (grammar == nullptr) {
-            throw std::runtime_error("Failed to initialize llama_grammar");
+    const T & rat(size_t i) const {
+        if (i >= sz) {
+            throw std::runtime_error("ring buffer: index out of bounds");
         }
-        ctx->grammar = grammar;
+        return data[(first + sz - i - 1) % capacity];
     }
 
-    std::fill(ctx->prev.begin(), ctx->prev.end(), 0);
-    ctx->cur.clear();
-    ctx->n_valid = 0;
-}
+    std::vector<T> to_vector() const {
+        std::vector<T> result;
+        result.reserve(sz);
+        for (size_t i = 0; i < sz; i++) {
+            result.push_back(data[(first + i) % capacity]);
+        }
+        return result;
+    }
 
-void llama_sampling_set_rng_seed(struct llama_sampling_context * ctx, uint32_t seed) {
-    if (seed == LLAMA_DEFAULT_SEED) {
-        seed = std::random_device{}();
+    void clear() {
+        // here only reset the status of the buffer
+        sz = 0;
+        first = 0;
+        pos = 0;
     }
-    ctx->rng.seed(seed);
-}
 
-void llama_sampling_cp(llama_sampling_context * src, llama_sampling_context * dst) {
-    if (dst->grammar) {
-        llama_grammar_free(dst->grammar);
-        dst->grammar = nullptr;
+    bool empty() const {
+        return sz == 0;
     }
 
-    if (src->grammar) {
-        dst->grammar = llama_grammar_copy(src->grammar);
+    size_t size() const {
+        return sz;
     }
 
-    dst->prev = src->prev;
-}
+    size_t capacity = 0;
+    size_t sz = 0;
+    size_t first = 0;
+    size_t pos = 0;
+    std::vector<T> data;
+};
 
-llama_token llama_sampling_last(llama_sampling_context * ctx) {
-    return ctx->prev.back();
-}
+struct gpt_sampler {
+    gpt_sampler_params params;
 
-std::string llama_sampling_prev_str(llama_sampling_context * ctx_sampling, llama_context * ctx_main, int n) {
-    const int size = ctx_sampling->prev.size();
+    struct llama_sampler * grmr;
+    struct llama_sampler * chain;
 
-    n = std::min(n, size);
+    ring_buffer<llama_token> prev;
 
-    std::string result;
+    std::vector<llama_token_data> cur;
 
-    for (int i = size - n; i < size; i++) {
-        result += llama_token_to_piece(ctx_main, ctx_sampling->prev[i]);
-    }
+    llama_token_data_array cur_p;
 
-    return result;
-}
+    void set_logits(struct llama_context * ctx, int idx) {
+        const auto * logits = llama_get_logits_ith(ctx, idx);
+
+        const int n_vocab = llama_n_vocab(llama_get_model(ctx));
+
+        cur.resize(n_vocab);
+
+        for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
+            cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f};
+        }
+
+        cur_p = { cur.data(), cur.size(), -1, false };
+    }
+};
 
-std::string llama_sampling_print(const llama_sampling_params & params) {
+std::string gpt_sampler_params::print() const {
     char result[1024];
 
     snprintf(result, sizeof(result),
             "\trepeat_last_n = %d, repeat_penalty = %.3f, frequency_penalty = %.3f, presence_penalty = %.3f\n"
             "\ttop_k = %d, tfs_z = %.3f, top_p = %.3f, min_p = %.3f, typical_p = %.3f, temp = %.3f\n"
             "\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f",
-            params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present,
-            params.top_k, params.tfs_z, params.top_p, params.min_p, params.typical_p, params.temp,
-            params.mirostat, params.mirostat_eta, params.mirostat_tau);
+            penalty_last_n, penalty_repeat, penalty_freq, penalty_present,
+            top_k, tfs_z, top_p, min_p, typ_p, temp,
+            mirostat, mirostat_eta, mirostat_tau);
 
     return std::string(result);
 }
 
-std::string llama_sampling_order_print(const llama_sampling_params & params) {
-    std::string result = "CFG -> Penalties ";
-    if (params.mirostat == 0) {
-        for (auto sampler_type : params.samplers_sequence) {
-            const auto sampler_type_name = llama_sampling_type_to_str(sampler_type);
-            if (!sampler_type_name.empty()) {
-                result += "-> " + sampler_type_name + " ";
+struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const struct gpt_sampler_params & params) {
+    llama_sampler_chain_params lparams = llama_sampler_chain_default_params();
+
+    lparams.no_perf = false; // TODO: control via params
+
+    auto * result = new gpt_sampler {
+        /* .params = */ params,
+        /* .grmr   = */ llama_sampler_init_grammar(model, params.grammar.c_str(), "root"),
+        /* .chain  = */ llama_sampler_chain_init(lparams),
+        /* .prev   = */ ring_buffer<llama_token>(params.n_prev),
+        /* .cur    = */ {},
+        /* .cur_p  = */ {},
+    };
+
+    llama_sampler_chain_add(result->chain,
+            llama_sampler_init_logit_bias(
+                llama_n_vocab(model),
+                params.logit_bias.size(),
+                params.logit_bias.data()));
+
+    llama_sampler_chain_add(result->chain,
+            llama_sampler_init_penalties(
+                llama_n_vocab  (model),
+                llama_token_eos(model),
+                llama_token_nl (model),
+                params.penalty_last_n,
+                params.penalty_repeat,
+                params.penalty_freq,
+                params.penalty_present,
+                params.penalize_nl,
+                params.ignore_eos));
+
+    if (params.temp > 0.0f) {
+        if (params.mirostat == 0) {
+            for (const auto & cnstr : params.samplers) {
+                switch (cnstr) {
+                    case GPT_SAMPLER_TYPE_TOP_K:
+                        llama_sampler_chain_add(result->chain, llama_sampler_init_top_k    (params.top_k));
+                        break;
+                    case GPT_SAMPLER_TYPE_TOP_P:
+                        llama_sampler_chain_add(result->chain, llama_sampler_init_top_p    (params.top_p, params.min_keep));
+                        break;
+                    case GPT_SAMPLER_TYPE_MIN_P:
+                        llama_sampler_chain_add(result->chain, llama_sampler_init_min_p    (params.min_p, params.min_keep));
+                        break;
+                    case GPT_SAMPLER_TYPE_TFS_Z:
+                        llama_sampler_chain_add(result->chain, llama_sampler_init_tail_free(params.tfs_z, params.min_keep));
+                        break;
+                    case GPT_SAMPLER_TYPE_TYPICAL_P:
+                        llama_sampler_chain_add(result->chain, llama_sampler_init_typical  (params.typ_p, params.min_keep));
+                        break;
+                    case GPT_SAMPLER_TYPE_TEMPERATURE:
+                        llama_sampler_chain_add(result->chain, llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent));
+                        break;
+                    default:
+                        GGML_ASSERT(false && "unknown sampler type");
+                }
             }
+            llama_sampler_chain_add(result->chain, llama_sampler_init_softmax());
+            llama_sampler_chain_add(result->chain, llama_sampler_init_dist(params.seed));
+        } else if (params.mirostat == 1) {
+            llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp));
+            llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat(llama_n_vocab(model), params.seed, params.mirostat_tau, params.mirostat_eta, 100));
+        } else if (params.mirostat == 2) {
+            llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp));
+            llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat_v2(params.seed, params.mirostat_tau, params.mirostat_eta));
+        } else {
+            GGML_ASSERT(false && "unknown mirostat version");
         }
     } else {
-        result += "-> mirostat ";
+        llama_sampler_chain_add(result->chain, llama_sampler_init_softmax());
+        llama_sampler_chain_add(result->chain, llama_sampler_init_greedy());
     }
 
     return result;
 }
 
-std::string llama_sampling_type_to_str(llama_sampler_type sampler_type) {
-    switch (sampler_type) {
-        case llama_sampler_type::TOP_K:       return "top_k";
-        case llama_sampler_type::TFS_Z:       return "tfs_z";
-        case llama_sampler_type::TYPICAL_P:   return "typical_p";
-        case llama_sampler_type::TOP_P:       return "top_p";
-        case llama_sampler_type::MIN_P:       return "min_p";
-        case llama_sampler_type::TEMPERATURE: return "temperature";
-        default : return "";
+void gpt_sampler_free(struct gpt_sampler * gsmpl) {
+    if (gsmpl) {
+        llama_sampler_free(gsmpl->grmr);
+
+        llama_sampler_free(gsmpl->chain);
+
+        delete gsmpl;
     }
 }
 
-std::vector<llama_sampler_type> llama_sampling_types_from_names(const std::vector<std::string> & names, bool allow_alt_names) {
-    std::unordered_map<std::string, llama_sampler_type> sampler_canonical_name_map {
-        {"top_k",       llama_sampler_type::TOP_K},
-        {"top_p",       llama_sampler_type::TOP_P},
-        {"typical_p",   llama_sampler_type::TYPICAL_P},
-        {"min_p",       llama_sampler_type::MIN_P},
-        {"tfs_z",       llama_sampler_type::TFS_Z},
-        {"temperature", llama_sampler_type::TEMPERATURE}
-    };
+void gpt_sampler_accept(struct gpt_sampler * gsmpl, llama_token token, bool accept_grammar) {
+    if (accept_grammar) {
+        llama_sampler_accept(gsmpl->grmr, token);
+    }
 
-    // since samplers names are written multiple ways
-    // make it ready for both system names and input names
-    std::unordered_map<std::string, llama_sampler_type> sampler_alt_name_map {
-        {"top-k",       llama_sampler_type::TOP_K},
-        {"top-p",       llama_sampler_type::TOP_P},
-        {"nucleus",     llama_sampler_type::TOP_P},
-        {"typical-p",   llama_sampler_type::TYPICAL_P},
-        {"typical",     llama_sampler_type::TYPICAL_P},
-        {"min-p",       llama_sampler_type::MIN_P},
-        {"tfs-z",       llama_sampler_type::TFS_Z},
-        {"tfs",         llama_sampler_type::TFS_Z},
-        {"temp",        llama_sampler_type::TEMPERATURE}
-    };
+    llama_sampler_accept(gsmpl->chain, token);
 
-    std::vector<llama_sampler_type> sampler_types;
-    sampler_types.reserve(names.size());
-    for (const auto & name : names)
-    {
-        auto sampler_item = sampler_canonical_name_map.find(name);
-        if (sampler_item != sampler_canonical_name_map.end())
-        {
-            sampler_types.push_back(sampler_item->second);
-        }
-        else
-        {
-            if (allow_alt_names)
-            {
-                sampler_item = sampler_alt_name_map.find(name);
-                if (sampler_item != sampler_alt_name_map.end())
-                {
-                    sampler_types.push_back(sampler_item->second);
-                }
-            }
-        }
-    }
-    return sampler_types;
+    gsmpl->prev.push_back(token);
 }
 
-std::vector<llama_sampler_type> llama_sampling_types_from_chars(const std::string & names_string) {
-    std::unordered_map<char, llama_sampler_type> sampler_name_map {
-        {'k', llama_sampler_type::TOP_K},
-        {'p', llama_sampler_type::TOP_P},
-        {'y', llama_sampler_type::TYPICAL_P},
-        {'m', llama_sampler_type::MIN_P},
-        {'f', llama_sampler_type::TFS_Z},
-        {'t', llama_sampler_type::TEMPERATURE}
-    };
+void gpt_sampler_reset(struct gpt_sampler * gsmpl) {
+    llama_sampler_reset(gsmpl->grmr);
 
-    std::vector<llama_sampler_type> sampler_types;
-    sampler_types.reserve(names_string.size());
-    for (const auto & c : names_string) {
-        const auto sampler_item = sampler_name_map.find(c);
-        if (sampler_item != sampler_name_map.end()) {
-            sampler_types.push_back(sampler_item->second);
-        }
-    }
-    return sampler_types;
+    llama_sampler_reset(gsmpl->chain);
 }
 
-// no reasons to expose this function in header
-static void sampler_queue(
-                   struct llama_context * ctx_main,
-            const llama_sampling_params & params,
-                 llama_token_data_array & cur_p,
-                                 size_t   min_keep) {
-    const float         temp              = params.temp;
-    const float         dynatemp_range    = params.dynatemp_range;
-    const float         dynatemp_exponent = params.dynatemp_exponent;
-    const int32_t       top_k             = params.top_k;
-    const float         top_p             = params.top_p;
-    const float         min_p             = params.min_p;
-    const float         tfs_z             = params.tfs_z;
-    const float         typical_p         = params.typical_p;
-    const std::vector<llama_sampler_type> & samplers_sequence = params.samplers_sequence;
-
-    for (auto sampler_type : samplers_sequence) {
-        switch (sampler_type) {
-            case llama_sampler_type::TOP_K    : llama_sample_top_k    (ctx_main, &cur_p, top_k,     min_keep); break;
-            case llama_sampler_type::TFS_Z    : llama_sample_tail_free(ctx_main, &cur_p, tfs_z,     min_keep); break;
-            case llama_sampler_type::TYPICAL_P: llama_sample_typical  (ctx_main, &cur_p, typical_p, min_keep); break;
-            case llama_sampler_type::TOP_P    : llama_sample_top_p    (ctx_main, &cur_p, top_p,     min_keep); break;
-            case llama_sampler_type::MIN_P    : llama_sample_min_p    (ctx_main, &cur_p, min_p,     min_keep); break;
-            case llama_sampler_type::TEMPERATURE:
-                if (dynatemp_range > 0) {
-                    float dynatemp_min = std::max(0.0f, temp - dynatemp_range);
-                    float dynatemp_max = std::max(0.0f, temp + dynatemp_range);
-                    llama_sample_entropy(ctx_main, &cur_p, dynatemp_min, dynatemp_max, dynatemp_exponent);
-                } else {
-                    llama_sample_temp(ctx_main, &cur_p, temp);
-                }
-                break;
-            default : break;
-        }
-    }
+struct gpt_sampler * gpt_sampler_clone(gpt_sampler * gsmpl) {
+    return new gpt_sampler {
+        /* .params = */ gsmpl->params,
+        /* .grmr   = */ llama_sampler_clone(gsmpl->grmr),
+        /* .chain  = */ llama_sampler_clone(gsmpl->chain),
+        /* .prev   = */ gsmpl->prev,
+        /* .cur    = */ gsmpl->cur,
+        /* .cur_p  = */ gsmpl->cur_p,
+    };
 }
 
-static llama_token llama_sampling_sample_impl(
-                  struct llama_sampling_context * ctx_sampling,
-                  struct llama_context * ctx_main,
-                  struct llama_context * ctx_cfg,
-                  const int idx,
-                  bool is_resampling) {
-    const llama_sampling_params & params = ctx_sampling->params;
-
-    const float   temp            = params.temp;
-    const int     mirostat        = params.mirostat;
-    const float   mirostat_tau    = params.mirostat_tau;
-    const float   mirostat_eta    = params.mirostat_eta;
-
-    std::vector<float> original_logits;
-    auto cur_p = llama_sampling_prepare(ctx_sampling, ctx_main, ctx_cfg, idx, /* apply_grammar= */ is_resampling, &original_logits);
-    if (ctx_sampling->grammar != NULL && !is_resampling) {
-        GGML_ASSERT(!original_logits.empty());
-    }
-    llama_token id = 0;
-
-    if (temp < 0.0) {
-        // greedy sampling, with probs
-        llama_sample_softmax(ctx_main, &cur_p);
-        id = cur_p.data[0].id;
-    } else if (temp == 0.0) {
-        // greedy sampling, no probs
-        id = llama_sample_token_greedy(ctx_main, &cur_p);
-    } else {
-        if (mirostat == 1) {
-            const int mirostat_m = 100;
-            llama_sample_temp(ctx_main, &cur_p, temp);
-            id = llama_sample_token_mirostat(ctx_main, &cur_p, mirostat_tau, mirostat_eta, mirostat_m, &ctx_sampling->mirostat_mu);
-        } else if (mirostat == 2) {
-            llama_sample_temp(ctx_main, &cur_p, temp);
-            id = llama_sample_token_mirostat_v2(ctx_main, &cur_p, mirostat_tau, mirostat_eta, &ctx_sampling->mirostat_mu);
-        } else {
-            // temperature sampling
-            size_t min_keep = std::max(1, params.min_keep);
-
-            sampler_queue(ctx_main, params, cur_p, min_keep);
+void gpt_perf_print(const struct llama_context * ctx, const struct gpt_sampler * gsmpl) {
+    // TODO: measure grammar performance
 
-            id = llama_sample_token_with_rng(ctx_main, &cur_p, ctx_sampling->rng);
+    if (gsmpl) {
+        llama_perf_print(gsmpl->chain, LLAMA_PERF_TYPE_SAMPLER_CHAIN);
+    }
+    if (ctx) {
+        llama_perf_print(ctx, LLAMA_PERF_TYPE_CONTEXT);
+    }
+}
 
-            //{
-            //    const int n_top = 10;
-            //    LOG("top %d candidates:\n", n_top);
+llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first) {
+    gsmpl->set_logits(ctx, idx);
 
-            //    for (int i = 0; i < n_top; i++) {
-            //        const llama_token id = cur_p.data[i].id;
-            //        (void)id; // To avoid a warning that id is unused when logging is disabled.
-            //        LOG(" - %5d: '%12s' (%.3f)\n", id, llama_token_to_piece(ctx_main, id).c_str(), cur_p.data[i].p);
-            //    }
-            //}
+    auto & grmr  = gsmpl->grmr;
+    auto & chain = gsmpl->chain;
+    auto & cur_p = gsmpl->cur_p; // initialized by set_logits
 
-            //LOG("sampled token: %5d: '%s'\n", id, llama_token_to_piece(ctx_main, id).c_str());
-        }
+    if (grammar_first) {
+        llama_sampler_apply(grmr, &cur_p);
     }
 
-    if (ctx_sampling->grammar != NULL && !is_resampling) {
-        // Get a pointer to the logits
-        float * logits = llama_get_logits_ith(ctx_main, idx);
+    llama_sampler_apply(chain, &cur_p);
 
-        // Create an array with a single token data element for the sampled id
-        llama_token_data single_token_data = {id, logits[id], 0.0f};
-        llama_token_data_array single_token_data_array = { &single_token_data, 1, false };
+    GGML_ASSERT(cur_p.selected != -1 && "no selected token during sampling - check your sampling configuration");
 
-        // Apply grammar constraints to the single token
-        llama_grammar_sample(ctx_sampling->grammar, ctx_main, &single_token_data_array);
+    const llama_token id = cur_p.data[cur_p.selected].id;
 
-        // Check if the token is valid according to the grammar by seeing if its logit has been set to -INFINITY
-        bool is_valid = single_token_data_array.data[0].logit != -INFINITY;
+    if (grammar_first) {
+        return id;
+    }
 
-        // If the token is not valid according to the grammar, perform resampling
-        if (!is_valid) {
-            LOG("Resampling because token %d: '%s' does not meet grammar rules\n", id, llama_token_to_piece(ctx_main, id).c_str());
+    // check if it the sampled token fits the grammar
+    {
+        llama_token_data       single_token_data       = { id, 1.0f, 0.0f };
+        llama_token_data_array single_token_data_array = { &single_token_data, 1, -1, false };
 
-            // Restore logits from the copy
-            std::copy(original_logits.begin(), original_logits.end(), logits);
+        llama_sampler_apply(grmr, &single_token_data_array);
 
-            return llama_sampling_sample_impl(ctx_sampling, ctx_main, ctx_cfg, idx, /* is_resampling= */ true);
+        const bool is_valid = single_token_data_array.data[0].logit != -INFINITY;
+        if (is_valid) {
+            return id;
         }
     }
 
-    ctx_sampling->n_valid = temp == 0.0f ? 0 : cur_p.size;
+    // resampling:
+    // if the token is not valid, sample again, but first apply the grammar sampler and then the sampling chain
+    gsmpl->set_logits(ctx, idx);
 
-    return id;
-}
+    llama_sampler_apply(grmr,  &cur_p);
+    llama_sampler_apply(chain, &cur_p);
 
-static llama_token_data_array llama_sampling_prepare_impl(
-                  struct llama_sampling_context * ctx_sampling,
-                  struct llama_context * ctx_main,
-                  struct llama_context * ctx_cfg,
-                  const int idx,
-                  bool apply_grammar,
-                  std::vector<float> * original_logits) {
-    const llama_sampling_params & params = ctx_sampling->params;
+    GGML_ASSERT(cur_p.selected != -1 && "no selected token during re-sampling - check your sampling configuration");
 
-    const int n_vocab = llama_n_vocab(llama_get_model(ctx_main));
+    return cur_p.data[cur_p.selected].id;
+}
 
-    const int32_t penalty_last_n  = params.penalty_last_n < 0 ? params.n_prev : params.penalty_last_n;
-    const float   penalty_repeat  = params.penalty_repeat;
-    const float   penalty_freq    = params.penalty_freq;
-    const float   penalty_present = params.penalty_present;
+// helpers
 
-    const bool    penalize_nl     = params.penalize_nl;
+llama_token_data_array * gpt_sampler_get_candidates(struct gpt_sampler * gsmpl) {
+    return &gsmpl->cur_p;
+}
 
-    auto & prev = ctx_sampling->prev;
-    auto & cur  = ctx_sampling->cur;
+llama_token gpt_sampler_last(const struct gpt_sampler * gsmpl) {
+    return gsmpl->prev.rat(0);
+}
 
-    // Get a pointer to the logits
-    float * logits = llama_get_logits_ith(ctx_main, idx);
+std::string gpt_sampler_print(const struct gpt_sampler * gsmpl) {
+    std::string result = "\tlogits ";
 
-    if (ctx_sampling->grammar != NULL && !apply_grammar) {
-        GGML_ASSERT(original_logits != NULL);
-        // Only make a copy of the original logits if we are not applying grammar checks, not sure if I actually have to do this.
-        *original_logits = {logits, logits + n_vocab};
+    for (int i = 0; i < llama_sampler_chain_n(gsmpl->chain); i++) {
+        const auto * smpl = llama_sampler_chain_get(gsmpl->chain, i);
+        result += std::string("-> ") + llama_sampler_name(smpl) + " ";
     }
 
-    // apply params.logit_bias map
-    for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) {
-        logits[it->first] += it->second;
+    return result;
+}
+
+std::string gpt_sampler_prev_str(gpt_sampler * gsmpl, llama_context * ctx_main, int n) {
+    n = std::min(n, (int) gsmpl->prev.size());
+
+    if (n <= 0) {
+        return "";
     }
 
-    if (ctx_cfg) {
-        float * logits_guidance = llama_get_logits_ith(ctx_cfg, idx);
-        llama_sample_apply_guidance(ctx_main, logits, logits_guidance, params.cfg_scale);
+    std::string result;
+    result.reserve(8*n); // 8 is the average length of a token [citation needed], TODO: compute this from the vocab
+
+    for (int i = n - 1; i >= 0; i--) {
+        const llama_token id = gsmpl->prev.rat(i);
+
+        GGML_ASSERT(id != LLAMA_TOKEN_NULL && "null token in the sampling history - should not happen");
+
+        result += llama_token_to_piece(ctx_main, id);
     }
 
-    cur.resize(n_vocab);
+    return result;
+}
 
-    for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
-        cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f};
+char gpt_sampler_type_to_chr(enum gpt_sampler_type cnstr) {
+    switch (cnstr) {
+        case GPT_SAMPLER_TYPE_TOP_K:       return 'k';
+        case GPT_SAMPLER_TYPE_TFS_Z:       return 'f';
+        case GPT_SAMPLER_TYPE_TYPICAL_P:   return 'y';
+        case GPT_SAMPLER_TYPE_TOP_P:       return 'p';
+        case GPT_SAMPLER_TYPE_MIN_P:       return 'm';
+        case GPT_SAMPLER_TYPE_TEMPERATURE: return 't';
+        default : return '?';
     }
+}
 
-    llama_token_data_array cur_p = { cur.data(), cur.size(), false };
+std::string gpt_sampler_type_to_str(enum gpt_sampler_type cnstr) {
+    switch (cnstr) {
+        case GPT_SAMPLER_TYPE_TOP_K:       return "top_k";
+        case GPT_SAMPLER_TYPE_TFS_Z:       return "tfs_z";
+        case GPT_SAMPLER_TYPE_TYPICAL_P:   return "typ_p";
+        case GPT_SAMPLER_TYPE_TOP_P:       return "top_p";
+        case GPT_SAMPLER_TYPE_MIN_P:       return "min_p";
+        case GPT_SAMPLER_TYPE_TEMPERATURE: return "temperature";
+        default : return "";
+    }
+}
 
-    // apply penalties
-    const auto& penalty_tokens = params.use_penalty_prompt_tokens ? params.penalty_prompt_tokens : prev;
-    const int penalty_tokens_used_size = std::min((int)penalty_tokens.size(), penalty_last_n);
-    if (penalty_tokens_used_size) {
-        const float nl_logit = logits[llama_token_nl(llama_get_model(ctx_main))];
+std::vector<gpt_sampler_type> gpt_sampler_types_from_names(const std::vector<std::string> & names, bool allow_alt_names) {
+    std::unordered_map<std::string, gpt_sampler_type> sampler_canonical_name_map {
+        { "top_k",       GPT_SAMPLER_TYPE_TOP_K },
+        { "top_p",       GPT_SAMPLER_TYPE_TOP_P },
+        { "typ_p",       GPT_SAMPLER_TYPE_TYPICAL_P },
+        { "min_p",       GPT_SAMPLER_TYPE_MIN_P },
+        { "tfs_z",       GPT_SAMPLER_TYPE_TFS_Z },
+        { "temperature", GPT_SAMPLER_TYPE_TEMPERATURE },
+    };
 
-        llama_sample_repetition_penalties(ctx_main, &cur_p,
-                penalty_tokens.data() + penalty_tokens.size() - penalty_tokens_used_size,
-                penalty_tokens_used_size, penalty_repeat, penalty_freq, penalty_present);
+    // since samplers names are written multiple ways
+    // make it ready for both system names and input names
+    std::unordered_map<std::string, gpt_sampler_type> sampler_alt_name_map {
+        { "top-k",       GPT_SAMPLER_TYPE_TOP_K },
+        { "top-p",       GPT_SAMPLER_TYPE_TOP_P },
+        { "nucleus",     GPT_SAMPLER_TYPE_TOP_P },
+        { "typical-p",   GPT_SAMPLER_TYPE_TYPICAL_P },
+        { "typical",     GPT_SAMPLER_TYPE_TYPICAL_P },
+        { "typ-p",       GPT_SAMPLER_TYPE_TYPICAL_P },
+        { "typ",         GPT_SAMPLER_TYPE_TYPICAL_P },
+        { "min-p",       GPT_SAMPLER_TYPE_MIN_P },
+        { "tfs-z",       GPT_SAMPLER_TYPE_TFS_Z },
+        { "tfs",         GPT_SAMPLER_TYPE_TFS_Z },
+        { "temp",        GPT_SAMPLER_TYPE_TEMPERATURE },
+    };
+
+    std::vector<gpt_sampler_type> samplers;
+    samplers.reserve(names.size());
 
-        if (!penalize_nl) {
-            for (size_t idx = 0; idx < cur_p.size; idx++) {
-                if (cur_p.data[idx].id == llama_token_nl(llama_get_model(ctx_main))) {
-                    cur_p.data[idx].logit = nl_logit;
-                    break;
+    for (const auto & name : names) {
+        auto sampler = sampler_canonical_name_map.find(name);
+        if (sampler != sampler_canonical_name_map.end()) {
+            samplers.push_back(sampler->second);
+        } else {
+            if (allow_alt_names) {
+                sampler = sampler_alt_name_map.find(name);
+                if (sampler != sampler_alt_name_map.end()) {
+                    samplers.push_back(sampler->second);
                 }
             }
         }
     }
 
-    // apply grammar checks before sampling logic
-    if (apply_grammar && ctx_sampling->grammar != NULL) {
-        llama_grammar_sample(ctx_sampling->grammar, ctx_main, &cur_p);
-    }
-
-    return cur_p;
+    return samplers;
 }
 
-llama_token llama_sampling_sample(
-                  struct llama_sampling_context * ctx_sampling,
-                  struct llama_context * ctx_main,
-                  struct llama_context * ctx_cfg,
-                  const int idx) {
-    // Call the implementation function with is_resampling set to false by default
-    return llama_sampling_sample_impl(ctx_sampling, ctx_main, ctx_cfg, idx, /* is_resampling= */ false);
-}
-
-llama_token_data_array llama_sampling_prepare(
-                  struct llama_sampling_context * ctx_sampling,
-                  struct llama_context * ctx_main,
-                  struct llama_context * ctx_cfg,
-                  const int idx,
-                  bool apply_grammar,
-                  std::vector<float> * original_logits) {
-    return llama_sampling_prepare_impl(ctx_sampling,ctx_main, ctx_cfg, idx, apply_grammar, original_logits);
-}
+std::vector<gpt_sampler_type> gpt_sampler_types_from_chars(const std::string & chars) {
+    std::unordered_map<char, gpt_sampler_type> sampler_name_map {
+        { gpt_sampler_type_to_chr(GPT_SAMPLER_TYPE_TOP_K),       GPT_SAMPLER_TYPE_TOP_K },
+        { gpt_sampler_type_to_chr(GPT_SAMPLER_TYPE_TFS_Z),       GPT_SAMPLER_TYPE_TFS_Z },
+        { gpt_sampler_type_to_chr(GPT_SAMPLER_TYPE_TYPICAL_P),   GPT_SAMPLER_TYPE_TYPICAL_P },
+        { gpt_sampler_type_to_chr(GPT_SAMPLER_TYPE_TOP_P),       GPT_SAMPLER_TYPE_TOP_P },
+        { gpt_sampler_type_to_chr(GPT_SAMPLER_TYPE_MIN_P),       GPT_SAMPLER_TYPE_MIN_P },
+        { gpt_sampler_type_to_chr(GPT_SAMPLER_TYPE_TEMPERATURE), GPT_SAMPLER_TYPE_TEMPERATURE }
+    };
 
-void llama_sampling_accept(
-        struct llama_sampling_context * ctx_sampling,
-        struct llama_context * ctx_main,
-        llama_token id,
-        bool apply_grammar) {
-    ctx_sampling->prev.erase(ctx_sampling->prev.begin());
-    ctx_sampling->prev.push_back(id);
+    std::vector<gpt_sampler_type> samplers;
+    samplers.reserve(chars.size());
 
-    if (ctx_sampling->grammar != NULL && apply_grammar) {
-        llama_grammar_accept_token(ctx_sampling->grammar, ctx_main, id);
+    for (const auto & c : chars) {
+        const auto sampler = sampler_name_map.find(c);
+        if (sampler != sampler_name_map.end()) {
+            samplers.push_back(sampler->second);
+        }
     }
+
+    return samplers;
 }
index eeaa53b8bcd008e2c5d7cfb330505de2706a5c05..654e0c513904d86282bce22059ff501981f47514 100644 (file)
 
 #include "llama.h"
 
-#include "grammar-parser.h"
-
-#include <random>
 #include <string>
-#include <unordered_map>
 #include <vector>
 
-// sampler types
-enum class llama_sampler_type : char {
-    TOP_K       = 'k',
-    TOP_P       = 'p',
-    MIN_P       = 'm',
-    TFS_Z       = 'f',
-    TYPICAL_P   = 'y',
-    TEMPERATURE = 't'
+enum gpt_sampler_type {
+    GPT_SAMPLER_TYPE_NONE        = 0,
+    GPT_SAMPLER_TYPE_TOP_K       = 1,
+    GPT_SAMPLER_TYPE_TOP_P       = 2,
+    GPT_SAMPLER_TYPE_MIN_P       = 3,
+    GPT_SAMPLER_TYPE_TFS_Z       = 4,
+    GPT_SAMPLER_TYPE_TYPICAL_P   = 5,
+    GPT_SAMPLER_TYPE_TEMPERATURE = 6,
 };
 
 // sampling parameters
-typedef struct llama_sampling_params {
-    int32_t     n_prev                = 64;                 // number of previous tokens to remember
-    int32_t     n_probs               = 0;                  // if greater than 0, output the probabilities of top n_probs tokens.
-    int32_t     min_keep              = 0;                  // 0 = disabled, otherwise samplers should return at least min_keep tokens
-    int32_t     top_k                 = 40;                 // <= 0 to use vocab size
-    float       top_p                 = 0.95f;              // 1.0 = disabled
-    float       min_p                 = 0.05f;              // 0.0 = disabled
-    float       tfs_z                 = 1.00f;              // 1.0 = disabled
-    float       typical_p             = 1.00f;              // 1.0 = disabled
-    float       temp                  = 0.80f;              // <= 0.0 to sample greedily, 0.0 to not output probabilities
-    float       dynatemp_range        = 0.00f;              // 0.0 = disabled
-    float       dynatemp_exponent     = 1.00f;              // controls how entropy maps to temperature in dynamic temperature sampler
-    int32_t     penalty_last_n        = 64;                 // last n tokens to penalize (0 = disable penalty, -1 = context size)
-    float       penalty_repeat        = 1.00f;              // 1.0 = disabled
-    float       penalty_freq          = 0.00f;              // 0.0 = disabled
-    float       penalty_present       = 0.00f;              // 0.0 = disabled
-    int32_t     mirostat              = 0;                  // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
-    float       mirostat_tau          = 5.00f;              // target entropy
-    float       mirostat_eta          = 0.10f;              // learning rate
-    bool        penalize_nl           = false;              // consider newlines as a repeatable token
-    uint32_t    seed                  = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampling_context
-
-    std::vector<llama_sampler_type> samplers_sequence = {
-        llama_sampler_type::TOP_K,
-        llama_sampler_type::TFS_Z,
-        llama_sampler_type::TYPICAL_P,
-        llama_sampler_type::TOP_P,
-        llama_sampler_type::MIN_P,
-        llama_sampler_type::TEMPERATURE
+struct gpt_sampler_params {
+    uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampler
+
+    int32_t n_prev            = 64;    // number of previous tokens to remember
+    int32_t n_probs           = 0;     // if greater than 0, output the probabilities of top n_probs tokens.
+    int32_t min_keep          = 0;     // 0 = disabled, otherwise samplers should return at least min_keep tokens
+    int32_t top_k             = 40;    // <= 0 to use vocab size
+    float   top_p             = 0.95f; // 1.0 = disabled
+    float   min_p             = 0.05f; // 0.0 = disabled
+    float   tfs_z             = 1.00f; // 1.0 = disabled
+    float   typ_p             = 1.00f; // typical_p, 1.0 = disabled
+    float   temp              = 0.80f; // <= 0.0 to sample greedily, 0.0 to not output probabilities
+    float   dynatemp_range    = 0.00f; // 0.0 = disabled
+    float   dynatemp_exponent = 1.00f; // controls how entropy maps to temperature in dynamic temperature sampler
+    int32_t penalty_last_n    = 64;    // last n tokens to penalize (0 = disable penalty, -1 = context size)
+    float   penalty_repeat    = 1.00f; // 1.0 = disabled
+    float   penalty_freq      = 0.00f; // 0.0 = disabled
+    float   penalty_present   = 0.00f; // 0.0 = disabled
+    int32_t mirostat          = 0;     // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
+    float   mirostat_tau      = 5.00f; // target entropy
+    float   mirostat_eta      = 0.10f; // learning rate
+    bool    penalize_nl       = false; // consider newlines as a repeatable token
+    bool    ignore_eos        = false;
+
+    std::vector<enum gpt_sampler_type> samplers = {
+        GPT_SAMPLER_TYPE_TOP_K,
+        GPT_SAMPLER_TYPE_TFS_Z,
+        GPT_SAMPLER_TYPE_TYPICAL_P,
+        GPT_SAMPLER_TYPE_TOP_P,
+        GPT_SAMPLER_TYPE_MIN_P,
+        GPT_SAMPLER_TYPE_TEMPERATURE
     };
 
-    std::string grammar;  // optional BNF-like grammar to constrain sampling
-
-    // Classifier-Free Guidance
-    // https://arxiv.org/abs/2306.17806
-    std::string cfg_negative_prompt; // string to help guidance
-    float       cfg_scale     = 1.f; // how strong is guidance
-
-    std::unordered_map<llama_token, float> logit_bias; // logit bias for specific tokens
-
-    std::vector<llama_token> penalty_prompt_tokens;
-    bool                     use_penalty_prompt_tokens = false;
-} llama_sampling_params;
-
-// general sampler context
-// TODO: move to llama.h
-struct llama_sampling_context {
-    // parameters that will be used for sampling
-    llama_sampling_params params;
-
-    // mirostat sampler state
-    float mirostat_mu;
+    std::string grammar; // optional BNF-like grammar to constrain sampling
 
-    llama_grammar * grammar;
+    std::vector<llama_logit_bias> logit_bias; // logit biases to apply
 
-    // internal
-    grammar_parser::parse_state parsed_grammar;
+    // print the parameters into a string
+    std::string print() const;
+};
 
-    // TODO: replace with ring-buffer
-    std::vector<llama_token>      prev;
-    std::vector<llama_token_data> cur;
-    size_t n_valid; // Number of correct top tokens with correct probabilities.
+// gpt_sampler extends llama_sampler with additional functionality:
+//
+//  - grammar support
+//  - custom sampler logic based on the parameters
+//  - history of the last accepted tokens
+//  - performance metrics
+//
+// This goal is to have a common implementation of the sampling logic shared across the examples.
+// For example, depending on the temperature, the sampling chain can be very simple (greedy) or more
+// complex (top-k, top-p, etc).
+//
+// Another example is related to the grammar. In general, the grammar constraints applied on the full
+// vocabulary can be very taxing. To improve performance, the grammar can be applied only to the sampled
+// token in order to verify if it fits the grammar. And only if the token doesn't fit the grammar, the
+// grammar constraints are applied to the full vocabulary and the token is resampled.
+//
+// The gpt_sampler also maintains a container with the last accepted tokens. In the future, this can
+// be moved into the core llama library.
+//
+// For convenience, the gpt_sampler also maintains a container with the current candidate tokens.
+// This can be used to access the probabilities of the rest of the non-sampled tokens.
+//
+// TODO: measure grammar performance
+//
 
-    std::mt19937 rng;
-};
+struct gpt_sampler;
 
-#include "common.h"
+// llama_sampler API overloads
 
-// Create a new sampling context instance.
-struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params);
+struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const struct gpt_sampler_params & params);
 
-void llama_sampling_free(struct llama_sampling_context * ctx);
+void gpt_sampler_free(struct gpt_sampler * gsmpl);
 
-// Reset the sampler context
-// - clear prev tokens
-// - reset grammar
-void llama_sampling_reset(llama_sampling_context * ctx);
+// if accept_grammar is true, the token is accepted both by the sampling chain and the grammar
+void                 gpt_sampler_accept(struct gpt_sampler * gsmpl, llama_token token, bool accept_grammar);
+void                 gpt_sampler_reset (struct gpt_sampler * gsmpl);
+struct gpt_sampler * gpt_sampler_clone (struct gpt_sampler * gsmpl);
 
-// Set the sampler seed
-void llama_sampling_set_rng_seed(struct llama_sampling_context * ctx, uint32_t seed);
+// arguments can be nullptr to skip printing
+void gpt_perf_print(const struct llama_context * ctx, const struct gpt_sampler * gsmpl);
 
-// Copy the sampler context
-void llama_sampling_cp(llama_sampling_context * src, llama_sampling_context * dst);
+// extended sampling implementation:
+//
+// - set logits
+// - apply the configured sampler chain
+// - check if the token fits the grammar (if any)
+// - if not: resample by first applying the grammar constraints and then sampling again (slower path)
+//
+// if grammar_first is true, the grammar is applied before the samplers (slower)
+// useful in cases where all the resulting candidates (not just the sampled one) must fit the grammar
+//
+llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first = false);
 
-// Get the last sampled token
-llama_token llama_sampling_last(llama_sampling_context * ctx);
+// helpers
 
-// Get a string representation of the last sampled tokens
-std::string llama_sampling_prev_str(llama_sampling_context * ctx_sampling, llama_context * ctx_main, int n);
+// access the internal list of current candidate tokens
+llama_token_data_array * gpt_sampler_get_candidates(struct gpt_sampler * gsmpl);
 
-// Print sampling parameters into a string
-std::string llama_sampling_print(const llama_sampling_params & params);
+// get the last accepted token
+llama_token gpt_sampler_last(const struct gpt_sampler * gsmpl);
 
-// Print sampling order into a string
-std::string llama_sampling_order_print(const llama_sampling_params & params);
+// print the sampler chain into a string
+std::string gpt_sampler_print(const struct gpt_sampler * gsmpl);
 
-std::string llama_sampling_type_to_str(llama_sampler_type sampler_type);
+// get a string representation of the last accepted tokens
+std::string gpt_sampler_prev_str(gpt_sampler * gsmpl, llama_context * ctx, int n);
 
-std::vector<llama_sampler_type> llama_sampling_types_from_names(const std::vector<std::string> & names, bool allow_alt_names);
-std::vector<llama_sampler_type> llama_sampling_types_from_chars(const std::string & names_string);
+char        gpt_sampler_type_to_chr(enum gpt_sampler_type cnstr);
+std::string gpt_sampler_type_to_str(enum gpt_sampler_type cnstr);
 
-// this is a common sampling function used across the examples for convenience
-// it can serve as a starting point for implementing your own sampling function
-// Note: When using multiple sequences, it is the caller's responsibility to call
-//       llama_sampling_reset when a sequence ends
-//
-// required:
-//  - ctx_main:     context to use for sampling
-//  - ctx_sampling: sampling-specific context
-//
-// optional:
-//  - ctx_cfg:      context to use for classifier-free guidance
-//  - idx:          sample from llama_get_logits_ith(ctx, idx)
-//
-// returns:
-//  - token:      sampled token
-//  - candidates: vector of candidate tokens
-//
-llama_token llama_sampling_sample(
-        struct llama_sampling_context * ctx_sampling,
-        struct llama_context * ctx_main,
-        struct llama_context * ctx_cfg,
-        int idx = -1);
-
-// Prepares and adjusts the set of token candidates for sampling based on penalties, biases, and sampling parameters.
-llama_token_data_array llama_sampling_prepare(
-        struct llama_sampling_context * ctx_sampling,
-        struct llama_context * ctx_main,
-        struct llama_context * ctx_cfg,
-        int idx = 0,
-        bool apply_grammar = true,
-        std::vector<float> * original_logits = nullptr);
-
-void llama_sampling_accept(
-        struct llama_sampling_context * ctx_sampling,
-        struct llama_context * ctx_main,
-        llama_token id,
-        bool apply_grammar);
+std::vector<enum gpt_sampler_type> gpt_sampler_types_from_names(const std::vector<std::string> & names, bool allow_alt_names);
+std::vector<enum gpt_sampler_type> gpt_sampler_types_from_chars(const std::string & chars);
index 25a950ea59a8ce6e5f2fcf98ceaf1f5aad833912..b043c74cc4954371a20d94fdeb060adb3d89977f 100644 (file)
@@ -210,7 +210,8 @@ int main(int argc, char ** argv) {
         }
     }
 
-    llama_print_timings(ctx);
+    LOG_TEE("\n");
+    llama_perf_print(ctx, LLAMA_PERF_TYPE_CONTEXT);
 
     llama_batch_free(batch);
 
index 616494d2d841d099af6332f34765121eff81a951..4bc2bbf2c1570845ba18db0db3c7dbf6ca46497f 100644 (file)
@@ -27,7 +27,6 @@ guard let model = llama_load_model_from_file(modelPath.cString(using: .utf8), mo
     print("Failed to load model")
     exit(1)
 }
-
 defer {
     llama_free_model(model)
 }
@@ -37,7 +36,6 @@ var tokens = tokenize(text: prompt, add_bos: true)
 let n_kv_req = UInt32(tokens.count) + UInt32((n_len - Int(tokens.count)) * n_parallel)
 
 var context_params = llama_context_default_params()
-context_params.seed = 1234
 context_params.n_ctx = n_kv_req
 context_params.n_batch = UInt32(max(n_len, n_parallel))
 context_params.n_threads = 8
@@ -48,11 +46,26 @@ guard context != nil else {
     print("Failed to initialize context")
     exit(1)
 }
-
 defer {
     llama_free(context)
 }
 
+var sparams = llama_sampler_chain_default_params()
+
+let smpl = llama_sampler_chain_init(sparams)
+guard smpl != nil else {
+    print("Failed to initialize sampling")
+    exit(1)
+}
+defer {
+    llama_sampler_free(smpl)
+}
+
+llama_sampler_chain_add(smpl, llama_sampler_init_top_k(40));
+llama_sampler_chain_add(smpl, llama_sampler_init_top_p(0.9, 1));
+llama_sampler_chain_add(smpl, llama_sampler_init_temp (0.4));
+llama_sampler_chain_add(smpl, llama_sampler_init_dist (1234));
+
 let n_ctx = llama_n_ctx(context)
 
 print("\nn_len = \(n_len), n_ctx = \(n_ctx), n_batch = \(context_params.n_batch), n_parallel = \(n_parallel), n_kv_req = \(n_kv_req)\n")
@@ -125,32 +138,9 @@ while n_cur <= n_len {
             continue
         }
 
-        var n_vocab = llama_n_vocab(model)
-        var logits = llama_get_logits_ith(context, i_batch[i])
-
-        var candidates: [llama_token_data] = .init(repeating: llama_token_data(), count: Int(n_vocab))
-
-        for token_id in 0 ..< n_vocab {
-            candidates.append(llama_token_data(id: token_id, logit: logits![Int(token_id)], p: 0.0))
-        }
-
-        var candidates_p: llama_token_data_array = .init(
-            data: &candidates,
-            size: candidates.count,
-            sorted: false
-        )
-
-        let top_k: Int32 = 40
-        let top_p: Float = 0.9
-        let temp: Float = 0.4
-
-        llama_sample_top_k(context, &candidates_p, top_k, 1)
-        llama_sample_top_p(context, &candidates_p, top_p, 1)
-        llama_sample_temp(context, &candidates_p, temp)
-
-        let new_token_id = llama_sample_token(context, &candidates_p)
+        let new_token_id = llama_sampler_sample(smpl, context, i_batch[i])
 
-        // const llama_token new_token_id = llama_sample_token_greedy(ctx, &candidates_p);
+        llama_sampler_accept(smpl, new_token_id)
 
         // is it an end of stream? -> mark the stream as finished
         if llama_token_is_eog(model, new_token_id) || n_cur == n_len {
@@ -210,9 +200,10 @@ if n_parallel > 1 {
 
 let t_main_end = ggml_time_us()
 
-print("decoded \(n_decode) tokens in \(String(format: "%.2f", Double(t_main_end - t_main_start) / 1_000_000.0)) s, speed: \(String(format: "%.2f", Double(n_decode) / (Double(t_main_end - t_main_start) / 1_000_000.0))) t/s\n")
+print("decoded \(n_decode) tokens in \(String(format: "%.2f", Double(t_main_end - t_main_start) / 1_000_000.0)) s, speed: \(String(format: "%.2f", Double(n_decode) / (Double(t_main_end - t_main_start) / 1_000_000.0))) t/s\n\n")
 
-llama_print_timings(context)
+llama_perf_print(UnsafeRawPointer(context), LLAMA_PERF_TYPE_CONTEXT)
+llama_perf_print(UnsafeRawPointer(smpl),    LLAMA_PERF_TYPE_SAMPLER_CHAIN)
 
 private func tokenize(text: String, add_bos: Bool) -> [llama_token] {
     let utf8Count = text.utf8.count
index 53fbfb0a8cf2ae88775c50506b2c6bef9d40e2af..f321f61047ad5ab5ec17eac065d5d6fb4e098a27 100644 (file)
@@ -2,7 +2,6 @@
 #include "llama.h"
 
 #include <algorithm>
-#include <cmath>
 #include <cstdio>
 #include <string>
 #include <vector>
@@ -65,6 +64,15 @@ int main(int argc, char ** argv) {
 
     llama_context * ctx = llama_new_context_with_model(model, ctx_params);
 
+    auto sparams = llama_sampler_chain_default_params();
+
+    llama_sampler * smpl = llama_sampler_chain_init(sparams);
+
+    llama_sampler_chain_add(smpl, llama_sampler_init_top_k(params.sparams.top_k));
+    llama_sampler_chain_add(smpl, llama_sampler_init_top_p(params.sparams.top_p, params.sparams.min_keep));
+    llama_sampler_chain_add(smpl, llama_sampler_init_temp (params.sparams.temp));
+    llama_sampler_chain_add(smpl, llama_sampler_init_dist (params.sparams.seed));
+
     if (ctx == NULL) {
         fprintf(stderr , "%s: error: failed to create the llama_context\n" , __func__);
         return 1;
@@ -164,29 +172,9 @@ int main(int argc, char ** argv) {
                 continue;
             }
 
-            auto   n_vocab = llama_n_vocab(model);
-            auto * logits  = llama_get_logits_ith(ctx, i_batch[i]);
-
-            std::vector<llama_token_data> candidates;
-            candidates.reserve(n_vocab);
-
-            for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
-                candidates.emplace_back(llama_token_data{ token_id, logits[token_id], 0.0f });
-            }
+            const llama_token new_token_id = llama_sampler_sample(smpl, ctx, i_batch[i]);
 
-            llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
-
-            const int   top_k = 40;
-            const float top_p = 0.9f;
-            const float temp  = 0.4f;
-
-            llama_sample_top_k(ctx, &candidates_p, top_k, 1);
-            llama_sample_top_p(ctx, &candidates_p, top_p, 1);
-            llama_sample_temp (ctx, &candidates_p, temp);
-
-            const llama_token new_token_id = llama_sample_token(ctx, &candidates_p);
-
-            //const llama_token new_token_id = llama_sample_token_greedy(ctx, &candidates_p);
+            llama_sampler_accept(smpl, new_token_id);
 
             // is it an end of generation? -> mark the stream as finished
             if (llama_token_is_eog(model, new_token_id) || n_cur == n_predict) {
@@ -244,12 +232,15 @@ int main(int argc, char ** argv) {
     LOG_TEE("%s: decoded %d tokens in %.2f s, speed: %.2f t/s\n",
             __func__, n_decode, (t_main_end - t_main_start) / 1000000.0f, n_decode / ((t_main_end - t_main_start) / 1000000.0f));
 
-    llama_print_timings(ctx);
+    LOG_TEE("\n");
+    llama_perf_print(smpl, LLAMA_PERF_TYPE_SAMPLER_CHAIN);
+    llama_perf_print(ctx,  LLAMA_PERF_TYPE_CONTEXT);
 
     fprintf(stderr, "\n");
 
     llama_batch_free(batch);
 
+    llama_sampler_free(smpl);
     llama_free(ctx);
     llama_free_model(model);
 
index b05aa006e7da51fcce9bde0984a69c8418265761..e5e0872b1ba4a4d8eaaf3d14c90d8732562bb8db 100644 (file)
@@ -90,13 +90,7 @@ int main(int argc, char ** argv) {
 
     print_build_info();
 
-    if (params.seed == LLAMA_DEFAULT_SEED) {
-        params.seed = time(NULL);
-    }
-
-    fprintf(stderr, "%s: seed  = %u\n", __func__, params.seed);
-
-    std::mt19937 rng(params.seed);
+    LOG_TEE("%s: seed = %u\n", __func__, params.sparams.seed);
 
     llama_backend_init();
     llama_numa_init(params.numa);
@@ -313,8 +307,10 @@ int main(int argc, char ** argv) {
         if (notArray) fprintf(stdout, "\n}\n");
     }
 
+    LOG_TEE("\n");
+    llama_perf_print(ctx, LLAMA_PERF_TYPE_CONTEXT);
+
     // clean up
-    llama_print_timings(ctx);
     llama_batch_free(batch);
     llama_free(ctx);
     llama_free_model(model);
index 5e89988e2beda3f117113f3f4a2933a97e442b37..aea15c864ea93fae51d5651f2f5337d4b8d12b32 100644 (file)
@@ -151,8 +151,6 @@ int main(int argc, char ** argv) {
 
     print_build_info();
 
-    std::mt19937 rng(params.seed);
-
     llama_backend_init();
     llama_numa_init(params.numa);
 
@@ -183,7 +181,8 @@ int main(int argc, char ** argv) {
         return 1;
     }
 
-    llama_print_timings(ctx);
+    LOG_TEE("\n");
+    llama_perf_print(ctx, LLAMA_PERF_TYPE_CONTEXT);
 
     llama_free(ctx);
     llama_free_model(model);
index 48a705e15cea996387cdb6c5c1c0505da9a7d7f8..7493af9d3aec32aef2c9e3330eb75d65b516f634 100644 (file)
@@ -1,9 +1,5 @@
-#define LLAMA_API_INTERNAL
-
-#include "grammar-parser.h"
-#include "ggml.h"
-#include "llama.h"
 #include "unicode.h"
+#include "llama-grammar.h"
 
 #include <cstdio>
 #include <cstdlib>
 #include <string>
 #include <vector>
 
-static bool llama_sample_grammar_string(struct llama_grammar * grammar, const std::string & input_str, size_t & error_pos, std::string & error_msg) {
-    auto decoded = decode_utf8(input_str, {});
-    const auto & code_points = decoded.first;
+static bool llama_grammar_validate(struct llama_grammar * grammar, const std::string & input_str, size_t & error_pos, std::string & error_msg) {
+    const auto cpts = unicode_cpts_from_utf8(input_str);
 
     const llama_grammar_rules  & rules      = llama_grammar_get_rules (grammar);
-          llama_grammar_stacks & cur_stacks = llama_grammar_get_stacks(grammar);
+          llama_grammar_stacks & stacks_cur = llama_grammar_get_stacks(grammar);
 
     size_t pos = 0;
-    for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
-        const llama_grammar_stacks prev_stacks = llama_grammar_get_stacks(grammar); // copy
+    for (const auto & cpt : cpts) {
+        const llama_grammar_stacks stacks_prev = llama_grammar_get_stacks(grammar); // copy
 
-        llama_grammar_accept(rules, prev_stacks, *it, cur_stacks);
+        llama_grammar_accept(rules, stacks_prev, cpt, stacks_cur);
 
-        if (cur_stacks.empty()) {
+        if (stacks_cur.empty()) {
             error_pos = pos;
-            error_msg = "Unexpected character '" + unicode_cpt_to_utf8(*it) + "'";
-            cur_stacks = prev_stacks;
+            error_msg = "Unexpected character '" + unicode_cpt_to_utf8(cpt) + "'";
+            stacks_cur = stacks_prev;
             return false;
         }
         ++pos;
     }
 
-    for (const auto & stack : cur_stacks) {
+    for (const auto & stack : stacks_cur) {
         if (stack.empty()) {
             return true;
         }
@@ -85,27 +80,7 @@ int main(int argc, char** argv) {
         grammar_str = buffer.str();
     }
 
-    // Parse the GBNF grammar
-    auto parsed_grammar = grammar_parser::parse(grammar_str.c_str());
-
-    // will be empty (default) if there are parse errors
-    if (parsed_grammar.rules.empty()) {
-        fprintf(stdout, "%s: failed to parse grammar\n", __func__);
-        return 1;
-    }
-
-    // Ensure that there is a "root" node.
-    if (parsed_grammar.symbol_ids.find("root") == parsed_grammar.symbol_ids.end()) {
-        fprintf(stdout, "%s: grammar does not contain a 'root' symbol\n", __func__);
-        return 1;
-    }
-
-    std::vector<const llama_grammar_element *> grammar_rules(parsed_grammar.c_rules());
-
-    // Create the LLAMA grammar
-    auto grammar = llama_grammar_init(
-            grammar_rules.data(),
-            grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
+    llama_grammar * grammar = llama_grammar_init_impl(nullptr, grammar_str.c_str(), "root");
     if (grammar == nullptr) {
         throw std::runtime_error("Failed to initialize llama_grammar");
     }
@@ -122,7 +97,7 @@ int main(int argc, char** argv) {
     // Validate the input string against the grammar
     size_t error_pos;
     std::string error_msg;
-    bool is_valid = llama_sample_grammar_string(grammar, input_str, error_pos, error_msg);
+    bool is_valid = llama_grammar_validate(grammar, input_str, error_pos, error_msg);
 
     if (is_valid) {
         fprintf(stdout, "Input string is valid according to the grammar.\n");
@@ -131,7 +106,7 @@ int main(int argc, char** argv) {
     }
 
     // Clean up
-    llama_grammar_free(grammar);
+    llama_grammar_free_impl(grammar);
 
     return 0;
 }
index 2c61c2e1eb3bc80002ead3172fa0052f11f133c2..4e801c69d2f06bbca1e68982178abc0193b903a6 100644 (file)
@@ -9,7 +9,7 @@
 static std::vector<std::vector<float>> encode(llama_context * ctx, const std::vector<std::string> & sentences, const std::string & instruction) {
     std::vector<std::vector<float>> result;
 
-    const llama_model * mdl = llama_get_model(ctx);
+    const llama_model * model = llama_get_model(ctx);
 
     llama_batch batch = llama_batch_init(llama_n_batch(ctx), 0, 1);
 
@@ -18,16 +18,16 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve
 
         const std::string input_string = instruction + sentences[i];
 
-        std::vector<llama_token> inputs = llama_tokenize(mdl, input_string, true, false);
+        std::vector<llama_token> inputs = llama_tokenize(model, input_string, true, false);
 
         const int32_t n_toks = inputs.size();
 
         // GritLM seems to have EOS = ""
         // https://github.com/ContextualAI/gritlm/blob/92025b16534712b31b3c4aaaf069350e222bd5f8/gritlm/gritlm.py#L18
-        // inputs.push_back(llama_token_eos(mdl));
+        // inputs.push_back(llama_token_eos(model));
 
         // we want to ignore instruction tokens for mean pooling
-        const int32_t n_inst = llama_tokenize(mdl, instruction, true, false).size();
+        const int32_t n_inst = llama_tokenize(model, instruction, true, false).size();
 
 #ifdef GRIT_DEBUG
         // debug tokens - should be matching as referenced in the GritLM sample
@@ -51,7 +51,7 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve
         llama_decode(ctx, batch);
 
         // get embedding dimensions
-        uint64_t n_embd = llama_n_embd(mdl);
+        uint64_t n_embd = llama_n_embd(model);
 
         // allocate embedding output
         std::vector<float> emb_unorm(n_embd, 0.0f);
@@ -92,11 +92,11 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve
     return result;
 }
 
-static std::string generate(llama_context * ctx, const std::string & prompt, bool stream) {
+static std::string generate(llama_context * ctx, llama_sampler * smpl, const std::string & prompt, bool stream) {
     std::string result;
 
-    const llama_model * mdl = llama_get_model(ctx);
-    llama_token eos_token = llama_token_eos(mdl);
+    const llama_model * model = llama_get_model(ctx);
+    llama_token eos_token = llama_token_eos(model);
 
     llama_kv_cache_clear(ctx);
     llama_set_embeddings(ctx, false);
@@ -104,28 +104,25 @@ static std::string generate(llama_context * ctx, const std::string & prompt, boo
 
     llama_batch bat = llama_batch_init(llama_n_batch(ctx), 0, 1);
 
-    std::vector<llama_token> inputs = llama_tokenize(mdl, prompt, false, true);
+    std::vector<llama_token> inputs = llama_tokenize(model, prompt, false, true);
     int32_t i_current_token = 0;
 
     while (true) {
         llama_batch_clear(bat);
-        auto n_inputs = (int32_t)inputs.size();
-        for (int32_t i = 0; i < n_inputs; i++) {
-            llama_batch_add(bat, inputs[i], i_current_token++, { 0 }, i == n_inputs - 1);
+        {
+            const int32_t n_inputs = inputs.size();
+
+            for (int32_t i = 0; i < n_inputs; i++) {
+                llama_batch_add(bat, inputs[i], i_current_token++, { 0 }, i == n_inputs - 1);
+            }
         }
         inputs.clear();
 
         llama_decode(ctx, bat);
-        auto logits = llama_get_logits_ith(ctx, bat.n_tokens - 1);
 
-        auto candidates = std::vector<llama_token_data>(llama_n_vocab(mdl));
-        auto n_candidates = (int32_t)candidates.size();
-        for (int32_t token = 0; token < n_candidates; token++) {
-            candidates[token] = llama_token_data{ token, logits[token], 0.0f };
-        }
-        auto candidates_p = llama_token_data_array{ candidates.data(), candidates.size(), false };
+        llama_token token = llama_sampler_sample(smpl, ctx, bat.n_tokens - 1);
+        llama_sampler_accept(smpl, token);
 
-        llama_token token = llama_sample_token_greedy(ctx, &candidates_p);
         if (token == eos_token) {
             break;
         }
@@ -167,10 +164,18 @@ int main(int argc, char * argv[]) {
 
     llama_backend_init();
 
-    llama_model * mdl = llama_load_model_from_file(params.model.c_str(), mparams);
+    llama_model * model = llama_load_model_from_file(params.model.c_str(), mparams);
 
     // create generation context
-    llama_context * ctx = llama_new_context_with_model(mdl, cparams);
+    llama_context * ctx = llama_new_context_with_model(model, cparams);
+
+    auto sparams = llama_sampler_chain_default_params();
+
+    sparams.no_perf = false;
+
+    llama_sampler * smpl = llama_sampler_chain_init(sparams);
+
+    llama_sampler_chain_add(smpl, llama_sampler_init_greedy());
 
     // ### Embedding/Representation ###
     // samples taken from: https://github.com/ContextualAI/gritlm#basic
@@ -191,7 +196,7 @@ int main(int argc, char * argv[]) {
         const std::vector<std::vector<float>> d_rep = encode(ctx, documents, gritlm_instruction(""));
         const std::vector<std::vector<float>> q_rep = encode(ctx, queries,   gritlm_instruction(instruction));
 
-        const int n_embd = llama_n_embd(mdl);
+        const int n_embd = llama_n_embd(model);
 
         const float cosine_sim_q0_d0 = llama_embd_similarity_cos(q_rep[0].data(), d_rep[0].data(), n_embd);
         const float cosine_sim_q0_d1 = llama_embd_similarity_cos(q_rep[0].data(), d_rep[1].data(), n_embd);
@@ -208,11 +213,12 @@ int main(int argc, char * argv[]) {
     // GritLM models are not finetuned with system prompts, as you can just include system-like instructions together with your user instruction
     {
         const std::string prompt = "<|user|>\nPlease write me a poem about my recent hike of Mt. Fuji at midnight in the style of Shakespeare.\n<|assistant|>\n";
-        std::string response = generate(ctx, prompt, true);
+        std::string response = generate(ctx, smpl, prompt, true);
     }
 
+    llama_sampler_free(smpl);
     llama_free(ctx);
-    llama_free_model(mdl);
+    llama_free_model(model);
     llama_backend_free();
 
     return 0;
index 83b85d72b043abe438e08d81769dd6305517da8a..107f8c8859dcf885f09ba1dd82822a6e4f31cc02 100644 (file)
@@ -638,7 +638,8 @@ int main(int argc, char ** argv) {
 
     g_collector.save_imatrix();
 
-    llama_print_timings(ctx);
+    LOG_TEE("\n");
+    llama_perf_print(ctx, LLAMA_PERF_TYPE_CONTEXT);
 
     llama_free(ctx);
     llama_free_model(model);
index 05700c1d591d9f232dee34f22f4696f4aa40b607..1ebc0b324bc8216983c81ee404cb7c8689dfc0cb 100644 (file)
@@ -2,7 +2,6 @@
 
 #include "console.h"
 #include "llama.h"
-#include "grammar-parser.h"
 
 #include <cassert>
 #include <cinttypes>
@@ -34,6 +33,7 @@
 
 static llama_context           ** g_ctx;
 static llama_model             ** g_model;
+static gpt_sampler             ** g_smpl;
 static gpt_params               * g_params;
 static std::vector<llama_token> * g_input_tokens;
 static std::ostringstream       * g_output_ss;
@@ -81,7 +81,7 @@ static void write_logfile(
     yaml_dump_string_multiline(logfile, "output", output.c_str());
     yaml_dump_vector_int(logfile, "output_tokens", output_tokens);
 
-    llama_dump_timing_info_yaml(logfile, ctx);
+    llama_perf_dump_yaml(logfile, ctx);
     fclose(logfile);
 }
 
@@ -93,7 +93,7 @@ static void sigint_handler(int signo) {
         } else {
             console::cleanup();
             printf("\n");
-            llama_print_timings(*g_ctx);
+            gpt_perf_print(*g_ctx, *g_smpl);
             write_logfile(*g_ctx, *g_params, *g_model, *g_input_tokens, g_output_ss->str(), *g_output_tokens);
             _exit(130);
         }
@@ -103,7 +103,6 @@ static void sigint_handler(int signo) {
 
 int main(int argc, char ** argv) {
     gpt_params params;
-    llama_sampling_params & sparams = params.sparams;
     g_params = &params;
 
     if (!gpt_params_parse(argc, argv, params)) {
@@ -111,6 +110,8 @@ int main(int argc, char ** argv) {
         return 1;
     }
 
+    auto & sparams = params.sparams;
+
 #ifndef LOG_DISABLE_LOGS
     log_set_target(log_filename_generator("infill", "log"));
     LOG_TEE("Log start\n");
@@ -156,26 +157,21 @@ int main(int argc, char ** argv) {
         LOG_TEE("%s: warning: scaling RoPE frequency by %g.\n", __func__, params.rope_freq_scale);
     }
 
-    LOG_TEE("%s: build = %d (%s)\n",      __func__, LLAMA_BUILD_NUMBER, LLAMA_COMMIT);
-    LOG_TEE("%s: built with %s for %s\n", __func__, LLAMA_COMPILER, LLAMA_BUILD_TARGET);
-
-    if (params.seed == LLAMA_DEFAULT_SEED) {
-        params.seed = time(NULL);
-    }
-
-    LOG_TEE("%s: seed  = %u\n", __func__, params.seed);
+    print_build_info();
 
-    std::mt19937 rng(params.seed);
+    LOG_TEE("%s: seed = %u\n", __func__, params.sparams.seed);
 
     LOG("%s: llama backend init\n", __func__);
     llama_backend_init();
     llama_numa_init(params.numa);
 
-    llama_model * model;
-    llama_context * ctx;
+    llama_model * model = nullptr;
+    llama_context * ctx = nullptr;
+    gpt_sampler  * smpl = nullptr;
 
     g_model = &model;
     g_ctx = &ctx;
+    g_smpl = &smpl;
 
     // load the model and apply lora adapter, if any
     LOG("%s: load the model and apply lora adapter, if any\n", __func__);
@@ -305,7 +301,7 @@ int main(int argc, char ** argv) {
             LOG_TEE("Input suffix: '%s'\n", params.input_suffix.c_str());
         }
     }
-    LOG_TEE("sampling: \n%s\n", llama_sampling_print(sparams).c_str());
+    LOG_TEE("sampling: \n%s\n", sparams.print().c_str());
     LOG_TEE("generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep);
     LOG_TEE("\n\n");
 
@@ -349,7 +345,7 @@ int main(int argc, char ** argv) {
 
     std::vector<llama_token> embd;
 
-    struct llama_sampling_context * ctx_sampling = llama_sampling_init(sparams);
+    smpl = gpt_sampler_init(model, sparams);
 
     while (n_remain != 0 || params.interactive) {
         // predict
@@ -421,11 +417,11 @@ int main(int argc, char ** argv) {
         embd.clear();
 
         if ((int) embd_inp.size() <= n_consumed && !is_interacting) {
-            const llama_token id = llama_sampling_sample(ctx_sampling, ctx, nullptr);
+            const llama_token id = gpt_sampler_sample(smpl, ctx, -1);
 
-            llama_sampling_accept(ctx_sampling, ctx, id, true);
+            gpt_sampler_accept(smpl, id, true);
 
-            LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, ctx_sampling->prev).c_str());
+            // LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, smpl->prev.to_vector()).c_str());
 
             embd.push_back(id);
 
@@ -444,7 +440,7 @@ int main(int argc, char ** argv) {
 
                 // push the prompt in the sampling context in order to apply repetition penalties later
                 // for the prompt, we don't apply grammar rules
-                llama_sampling_accept(ctx_sampling, ctx, embd_inp[n_consumed], false);
+                gpt_sampler_accept(smpl, embd_inp[n_consumed], false);
 
                 ++n_consumed;
                 if ((int) embd.size() >= params.n_batch) {
@@ -476,7 +472,7 @@ int main(int argc, char ** argv) {
         // if not currently processing queued inputs;
         if ((int) embd_inp.size() <= n_consumed) {
             // deal with eot token in infill mode
-            if ((llama_sampling_last(ctx_sampling) == llama_token_eot(model) || is_interacting) && params.interactive){
+            if ((gpt_sampler_last(smpl) == llama_token_eot(model) || is_interacting) && params.interactive){
                 if (is_interacting && !params.interactive_first) {
                     // print an eot token
                     printf("%s", llama_token_to_piece(ctx, llama_token_eot(model)).c_str());
@@ -542,7 +538,7 @@ int main(int argc, char ** argv) {
                 is_interacting = false;
             }
             // deal with end of generation tokens in interactive mode
-            else if (llama_token_is_eog(model, llama_sampling_last(ctx_sampling))) {
+            else if (llama_token_is_eog(model, gpt_sampler_last(smpl))) {
                 LOG("found EOS token\n");
 
                 if (params.interactive) {
@@ -615,7 +611,7 @@ int main(int argc, char ** argv) {
 
             if (n_past > 0) {
                 if (is_interacting) {
-                    llama_sampling_reset(ctx_sampling);
+                    gpt_sampler_reset(smpl);
                 }
                 is_interacting = false;
             }
@@ -638,13 +634,14 @@ int main(int argc, char ** argv) {
         fflush(stdout);
     }
 
-    llama_print_timings(ctx);
+    LOG_TEE("\n");
+    gpt_perf_print(ctx, smpl);
     write_logfile(ctx, params, model, input_tokens, output_ss.str(), output_tokens);
 
     llama_free(ctx);
     llama_free_model(model);
 
-    llama_sampling_free(ctx_sampling);
+    gpt_sampler_free(smpl);
     llama_backend_free();
 
 #ifndef LOG_DISABLE_LOGS
index fe1802b51bdf6542315d38f80fb47f837ce1fa77..d7db5af722a601783194c07647debfbbc5ba74e2 100644 (file)
@@ -1630,7 +1630,7 @@ int main(int argc, char ** argv) {
             fflush(p_err->fout);
         }
 
-        llama_print_timings(ctx);
+        llama_perf_print(ctx, LLAMA_PERF_TYPE_CONTEXT);
 
         llama_free(ctx);
 
index 2aafe23167557f66e16c82fed9d86ca8494288cd..9217937512d75f731057e966ce2ffb1cbad4e4ae 100644 (file)
@@ -120,8 +120,8 @@ Java_android_llama_cpp_LLamaAndroid_new_1context(JNIEnv *env, jobject, jlong jmo
     LOGi("Using %d threads", n_threads);
 
     llama_context_params ctx_params = llama_context_default_params();
-    ctx_params.seed  = 1234;
-    ctx_params.n_ctx = 2048;
+
+    ctx_params.n_ctx           = 2048;
     ctx_params.n_threads       = n_threads;
     ctx_params.n_threads_batch = n_threads;
 
@@ -380,11 +380,13 @@ Java_android_llama_cpp_LLamaAndroid_completion_1loop(
         JNIEnv * env,
         jobject,
         jlong context_pointer,
+        jlong sampling_pointer,
         jlong batch_pointer,
         jint n_len,
         jobject intvar_ncur
 ) {
     const auto context = reinterpret_cast<llama_context *>(context_pointer);
+    const auto sampling = reinterpret_cast<llama_sampler *>(sampling_pointer);
     const auto batch = reinterpret_cast<llama_batch *>(batch_pointer);
     const auto model = llama_get_model(context);
 
@@ -392,20 +394,10 @@ Java_android_llama_cpp_LLamaAndroid_completion_1loop(
     if (!la_int_var_value) la_int_var_value = env->GetMethodID(la_int_var, "getValue", "()I");
     if (!la_int_var_inc) la_int_var_inc = env->GetMethodID(la_int_var, "inc", "()V");
 
-    auto n_vocab = llama_n_vocab(model);
-    auto logits = llama_get_logits_ith(context, batch->n_tokens - 1);
-
-    std::vector<llama_token_data> candidates;
-    candidates.reserve(n_vocab);
-
-    for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
-        candidates.emplace_back(llama_token_data{ token_id, logits[token_id], 0.0f });
-    }
-
-    llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
-
     // sample the most likely token
-    const auto new_token_id = llama_sample_token_greedy(context, &candidates_p);
+    const auto new_token_id = llama_sampler_sample(sampling, context, batch->n_tokens - 1);
+
+    llama_sampler_accept(sampling, new_token_id);
 
     const auto n_cur = env->CallIntMethod(intvar_ncur, la_int_var_value);
     if (llama_token_is_eog(model, new_token_id) || n_cur == n_len) {
index 48b7840ae49c3ade861aed464cdcbc8bce8996cb..92f61fe83081d6bcc6f9eaabb1fc4d1d5291989d 100644 (file)
@@ -24,6 +24,7 @@ func llama_batch_add(_ batch: inout llama_batch, _ id: llama_token, _ pos: llama
 actor LlamaContext {
     private var model: OpaquePointer
     private var context: OpaquePointer
+    private var sampling: UnsafeMutablePointer<llama_sampler>
     private var batch: llama_batch
     private var tokens_list: [llama_token]
     var is_done: Bool = false
@@ -42,9 +43,15 @@ actor LlamaContext {
         self.tokens_list = []
         self.batch = llama_batch_init(512, 0, 1)
         self.temporary_invalid_cchars = []
+        let sparams = llama_sampler_chain_default_params()
+        self.sampling = llama_sampler_chain_init(sparams)
+        llama_sampler_chain_add(self.sampling, llama_sampler_init_temp(0.4))
+        llama_sampler_chain_add(self.sampling, llama_sampler_init_softmax())
+        llama_sampler_chain_add(self.sampling, llama_sampler_init_dist(1234))
     }
 
     deinit {
+        llama_sampler_free(sampling)
         llama_batch_free(batch)
         llama_free(context)
         llama_free_model(model)
@@ -69,7 +76,6 @@ actor LlamaContext {
         print("Using \(n_threads) threads")
 
         var ctx_params = llama_context_default_params()
-        ctx_params.seed  = 1234
         ctx_params.n_ctx = 2048
         ctx_params.n_threads       = Int32(n_threads)
         ctx_params.n_threads_batch = Int32(n_threads)
@@ -144,20 +150,9 @@ actor LlamaContext {
     func completion_loop() -> String {
         var new_token_id: llama_token = 0
 
-        let n_vocab = llama_n_vocab(model)
-        let logits = llama_get_logits_ith(context, batch.n_tokens - 1)
+        new_token_id = llama_sampler_sample(sampling, context, batch.n_tokens - 1)
 
-        var candidates = Array<llama_token_data>()
-        candidates.reserveCapacity(Int(n_vocab))
-
-        for token_id in 0..<n_vocab {
-            candidates.append(llama_token_data(id: token_id, logit: logits![Int(token_id)], p: 0.0))
-        }
-        candidates.withUnsafeMutableBufferPointer() { buffer in
-            var candidates_p = llama_token_data_array(data: buffer.baseAddress, size: buffer.count, sorted: false)
-
-            new_token_id = llama_sample_token_greedy(context, &candidates_p)
-        }
+        llama_sampler_accept(sampling, new_token_id)
 
         if llama_token_is_eog(model, new_token_id) || n_cur == n_len {
             print("\n")
index 86b39f20eea6e10828f7d3f59a2a2b207dba6711..4d7ccc91fc4b4222d5ac20e3e4a68b104c7a21b2 100644 (file)
@@ -40,11 +40,11 @@ static bool eval_string(struct llama_context * ctx_llama, const char* str, int n
     return true;
 }
 
-static const char * sample(struct llama_sampling_context * ctx_sampling,
+static const char * sample(struct gpt_sampler * smpl,
                            struct llama_context * ctx_llama,
                            int * n_past) {
-    const llama_token id = llama_sampling_sample(ctx_sampling, ctx_llama, NULL);
-    llama_sampling_accept(ctx_sampling, ctx_llama, id, true);
+    const llama_token id = gpt_sampler_sample(smpl, ctx_llama, -1);
+    gpt_sampler_accept(smpl, id, true);
     static std::string ret;
     if (llama_token_is_eog(llama_get_model(ctx_llama), id)) {
         ret = "</s>";
@@ -191,15 +191,15 @@ static void process_prompt(struct llava_context * ctx_llava, struct llava_image_
 
     LOG_TEE("\n");
 
-    struct llama_sampling_context * ctx_sampling = llama_sampling_init(params->sparams);
-    if (!ctx_sampling) {
+    struct gpt_sampler * smpl = gpt_sampler_init(ctx_llava->model, params->sparams);
+    if (!smpl) {
         fprintf(stderr, "%s: failed to initialize sampling subsystem\n", __func__);
         exit(1);
     }
 
     std::string response = "";
     for (int i = 0; i < max_tgt_len; i++) {
-        const char * tmp = sample(ctx_sampling, ctx_llava->ctx_llama, &n_past);
+        const char * tmp = sample(smpl, ctx_llava->ctx_llama, &n_past);
         response += tmp;
         if (strcmp(tmp, "</s>") == 0) break;
         if (strstr(tmp, "###")) break; // Yi-VL behavior
@@ -211,7 +211,7 @@ static void process_prompt(struct llava_context * ctx_llava, struct llava_image_
         fflush(stdout);
     }
 
-    llama_sampling_free(ctx_sampling);
+    gpt_sampler_free(smpl);
     printf("\n");
 }
 
@@ -310,7 +310,7 @@ int main(int argc, char ** argv) {
         // process the prompt
         process_prompt(ctx_llava, image_embed, &params, params.prompt);
 
-        llama_print_timings(ctx_llava->ctx_llama);
+        llama_perf_print(ctx_llava->ctx_llama, LLAMA_PERF_TYPE_CONTEXT);
         llava_image_embed_free(image_embed);
         ctx_llava->model = NULL;
         llava_free(ctx_llava);
@@ -327,7 +327,7 @@ int main(int argc, char ** argv) {
             // process the prompt
             process_prompt(ctx_llava, image_embed, &params, params.prompt);
 
-            llama_print_timings(ctx_llava->ctx_llama);
+            llama_perf_print(ctx_llava->ctx_llama, LLAMA_PERF_TYPE_CONTEXT);
             llava_image_embed_free(image_embed);
             ctx_llava->model = NULL;
             llava_free(ctx_llava);
index f500ea5b944f47e3931a1163dad02cda3907243c..237da9429ecc6a19f78012960be99799f1887557 100644 (file)
@@ -163,11 +163,11 @@ static void process_image(struct llava_context * ctx_llava, struct llava_image_e
     LOG_TEE("%s: image token past: %d\n", __func__, n_past);
 }
 
-static const char * sample(struct llama_sampling_context * ctx_sampling,
+static const char * sample(struct gpt_sampler * smpl,
                            struct llama_context * ctx_llama,
                            int * n_past) {
-    const llama_token id = llama_sampling_sample(ctx_sampling, ctx_llama, NULL);
-    llama_sampling_accept(ctx_sampling, ctx_llama, id, true);
+    const llama_token id = gpt_sampler_sample(smpl, ctx_llama, -1);
+    gpt_sampler_accept(smpl, id, true);
     static std::string ret;
     if (llama_token_is_eog(llama_get_model(ctx_llama), id)) {
         ret = "</s>";
@@ -214,7 +214,7 @@ static struct llava_context * minicpmv_init(gpt_params * params, const std::stri
     return ctx_llava;
 }
 
-static struct llama_sampling_context * llama_init(struct llava_context * ctx_llava, gpt_params * params, std::string prompt, int &n_past, bool is_first = false){
+static struct gpt_sampler * llama_init(struct llava_context * ctx_llava, gpt_params * params, std::string prompt, int &n_past, bool is_first = false){
     std::string user_prompt = prompt;
     int has_minicpmv_projector = clip_is_minicpmv(ctx_llava->ctx_clip);
     if (!is_first) {
@@ -238,13 +238,13 @@ static struct llama_sampling_context * llama_init(struct llava_context * ctx_lla
 
     LOG_TEE("\n");
 
-    struct llama_sampling_context * ctx_sampling = llama_sampling_init(params->sparams);
-    return ctx_sampling;
+    struct gpt_sampler * smpl = gpt_sampler_init(ctx_llava->model, params->sparams);
+    return smpl;
 }
 
-static const char * llama_loop(struct llava_context * ctx_llava,struct llama_sampling_context * ctx_sampling, int &n_past){
+static const char * llama_loop(struct llava_context * ctx_llava,struct gpt_sampler * smpl, int &n_past){
 
-    const char * tmp = sample(ctx_sampling, ctx_llava->ctx_llama, &n_past);
+    const char * tmp = sample(smpl, ctx_llava->ctx_llama, &n_past);
     return tmp;
 }
 
@@ -278,12 +278,12 @@ int main(int argc, char ** argv) {
         if (!params.prompt.empty()) {
             LOG_TEE("<user>%s\n", params.prompt.c_str());
             LOG_TEE("<assistant>");
-            auto ctx_sampling = llama_init(ctx_llava, &params, params.prompt.c_str(), n_past, true);
+            auto smpl = llama_init(ctx_llava, &params, params.prompt.c_str(), n_past, true);
             const int max_tgt_len = params.n_predict < 0 ? 256 : params.n_predict;
             std::string response = "";
             bool have_tmp = false;
             for (int i = 0; i < max_tgt_len; i++) {
-                auto tmp = llama_loop(ctx_llava, ctx_sampling, n_past);
+                auto tmp = llama_loop(ctx_llava, smpl, n_past);
                 response += tmp;
                 if (strcmp(tmp, "</s>") == 0){
                     if(!have_tmp)continue;
@@ -296,18 +296,18 @@ int main(int argc, char ** argv) {
 
                 fflush(stdout);
             }
-            llama_sampling_free(ctx_sampling);
+            gpt_sampler_free(smpl);
         }else {
             while (true) {
                 LOG_TEE("<user>");
                 std::string prompt;
                 std::getline(std::cin, prompt);
                 LOG_TEE("<assistant>");
-                auto ctx_sampling = llama_init(ctx_llava, &params, prompt, n_past, true);
+                auto smpl = llama_init(ctx_llava, &params, prompt, n_past, true);
                 const int max_tgt_len = params.n_predict < 0 ? 256 : params.n_predict;
                 std::string response = "";
                 for (int i = 0; i < max_tgt_len; i++) {
-                    auto tmp = llama_loop(ctx_llava, ctx_sampling, n_past);
+                    auto tmp = llama_loop(ctx_llava, smpl, n_past);
                     response += tmp;
                     if (strcmp(tmp, "</s>") == 0) break;
                     if (strstr(tmp, "###")) break; // Yi-VL behavior
@@ -315,11 +315,11 @@ int main(int argc, char ** argv) {
                     if (strstr(response.c_str(), "<user>")) break; // minicpm-v
                     fflush(stdout);
                 }
-                llama_sampling_free(ctx_sampling);
+                gpt_sampler_free(smpl);
             }
         }
         printf("\n");
-        llama_print_timings(ctx_llava->ctx_llama);
+        llama_perf_print(ctx_llava->ctx_llama, LLAMA_PERF_TYPE_CONTEXT);
 
         ctx_llava->model = NULL;
         llava_free(ctx_llava);
index 81cf1629c5b6ae5c46eebd766c09d0317693604d..c2e931c651008c0f10fad5159ca244beb8d0e5ac 100644 (file)
@@ -1,7 +1,6 @@
 #include "common.h"
 #include "llama.h"
 
-#include <cmath>
 #include <cstdio>
 #include <string>
 #include <vector>
@@ -118,7 +117,7 @@ int main(int argc, char ** argv) {
     llama_batch batch = llama_batch_init(params.n_ctx, 0, W + G + 1);
 
     // target model sampling context
-    struct llama_sampling_context * ctx_sampling = llama_sampling_init(params.sparams);
+    struct gpt_sampler * smpl = gpt_sampler_init(model, params.sparams);
 
     // verification n-grams
     std::vector<ngram_data> ngrams_cur(G);
@@ -159,9 +158,9 @@ int main(int argc, char ** argv) {
 
     // sample first token
     {
-        id = llama_sampling_sample(ctx_sampling, ctx, NULL, 0);
+        id = gpt_sampler_sample(smpl, ctx, 0);
 
-        llama_sampling_accept(ctx_sampling, ctx, id, true);
+        gpt_sampler_accept(smpl, id, true);
 
         {
             const std::string token_str = llama_token_to_piece(ctx, id);
@@ -284,9 +283,9 @@ int main(int argc, char ** argv) {
             }
 
             // sample the next token
-            id = llama_sampling_sample(ctx_sampling, ctx, NULL, i_batch);
+            id = gpt_sampler_sample(smpl, ctx, i_batch);
 
-            llama_sampling_accept(ctx_sampling, ctx, id, true);
+            gpt_sampler_accept(smpl, id, true);
 
             // print
             {
@@ -361,7 +360,7 @@ int main(int argc, char ** argv) {
                 if (v == 0) {
                     // sample from the last level
                     for (int i = 0; i < W; i++) {
-                        tokens_j[N - 2][i] = llama_sampling_sample(ctx_sampling, ctx, NULL, ngrams_cur.size()*(N-1) + W*(N - 2) + i);
+                        tokens_j[N - 2][i] = gpt_sampler_sample(smpl, ctx, ngrams_cur.size()*(N-1) + W*(N - 2) + i);
                     }
                 } else {
                     for (int i = 0; i < W; i++) {
@@ -468,10 +467,12 @@ int main(int argc, char ** argv) {
     LOG_TEE("n_predict = %d\n", n_predict);
     LOG_TEE("n_accept  = %d\n", n_accept);
 
-    llama_print_timings(ctx);
+    LOG_TEE("\n");
+    gpt_perf_print(ctx, smpl);
+
+    gpt_sampler_free(smpl);
 
     llama_kv_cache_view_free(&kvc_view);
-    llama_sampling_free(ctx_sampling);
 
     llama_batch_free(batch);
 
index d53a9828c2ea23918c01b15b507579e5c0647213..071400b7e7f7ea62c2a0d4dfef2ecf62bca1d220 100644 (file)
@@ -3,13 +3,11 @@
 #include "common.h"
 #include "ngram-cache.h"
 
-#include <cmath>
 #include <cstdint>
 #include <cstdio>
 #include <fstream>
 #include <string>
 #include <vector>
-#include <unordered_map>
 
 int main(int argc, char ** argv){
     gpt_params params;
@@ -106,7 +104,7 @@ int main(int argc, char ** argv){
 
     bool has_eos = false;
 
-    struct llama_sampling_context * ctx_sampling = llama_sampling_init(params.sparams);
+    struct gpt_sampler * smpl = gpt_sampler_init(model, params.sparams);
 
     std::vector<llama_token> draft;
 
@@ -130,9 +128,9 @@ int main(int argc, char ** argv){
         int i_dft = 0;
         while (true) {
             // sample from the target model
-            llama_token id = llama_sampling_sample(ctx_sampling, ctx, NULL, i_dft);
+            llama_token id = gpt_sampler_sample(smpl, ctx, i_dft);
 
-            llama_sampling_accept(ctx_sampling, ctx, id, true);
+            gpt_sampler_accept(smpl, id, true);
 
             const std::string token_str = llama_token_to_piece(ctx, id);
 
@@ -240,10 +238,12 @@ int main(int argc, char ** argv){
     LOG_TEE("n_accept     = %d\n", n_accept);
     LOG_TEE("accept       = %.3f%%\n", 100.0f * n_accept / n_drafted);
 
-    LOG_TEE("\ntarget:\n");
-    llama_print_timings(ctx);
+    LOG_TEE("\ntarget:\n\n");
+    llama_perf_print(smpl, LLAMA_PERF_TYPE_SAMPLER_CHAIN);
+    llama_perf_print(ctx,  LLAMA_PERF_TYPE_CONTEXT);
+
+    gpt_sampler_free(smpl);
 
-    llama_sampling_free(ctx_sampling);
     llama_batch_free(batch_tgt);
 
     llama_free(ctx);
index c55efbb66d7c12c5195ff2a70ee0dc12f61e063d..42058d41de35d0de609bfa9faa8d3baffb7ff03f 100644 (file)
@@ -33,6 +33,7 @@
 
 static llama_context           ** g_ctx;
 static llama_model             ** g_model;
+static gpt_sampler             ** g_smpl;
 static gpt_params               * g_params;
 static std::vector<llama_token> * g_input_tokens;
 static std::ostringstream       * g_output_ss;
@@ -92,7 +93,7 @@ static void write_logfile(
     yaml_dump_string_multiline(logfile, "output", output.c_str());
     yaml_dump_vector_int(logfile, "output_tokens", output_tokens);
 
-    llama_dump_timing_info_yaml(logfile, ctx);
+    llama_perf_dump_yaml(logfile, ctx);
     fclose(logfile);
 }
 
@@ -105,7 +106,7 @@ static void sigint_handler(int signo) {
         } else {
             console::cleanup();
             printf("\n");
-            llama_print_timings(*g_ctx);
+            gpt_perf_print(*g_ctx, *g_smpl);
             write_logfile(*g_ctx, *g_params, *g_model, *g_input_tokens, g_output_ss->str(), *g_output_tokens);
             _exit(130);
         }
@@ -121,8 +122,7 @@ static void llama_log_callback_logTee(ggml_log_level level, const char * text, v
 
 static std::string chat_add_and_format(struct llama_model * model, std::vector<llama_chat_msg> & chat_msgs, std::string role, std::string content) {
     llama_chat_msg new_msg{role, content};
-    auto formatted = llama_chat_format_single(
-        model, g_params->chat_template, chat_msgs, new_msg, role == "user");
+    auto formatted = llama_chat_format_single(model, g_params->chat_template, chat_msgs, new_msg, role == "user");
     chat_msgs.push_back({role, content});
     LOG("formatted: %s\n", formatted.c_str());
     return formatted;
@@ -137,7 +137,7 @@ int main(int argc, char ** argv) {
         return 1;
     }
 
-    llama_sampling_params & sparams = params.sparams;
+    auto & sparams = params.sparams;
 
 #ifndef LOG_DISABLE_LOGS
     log_set_target(log_filename_generator("main", "log"));
@@ -183,27 +183,23 @@ int main(int argc, char ** argv) {
         LOG_TEE("%s: warning: scaling RoPE frequency by %g.\n", __func__, params.rope_freq_scale);
     }
 
-    LOG_TEE("%s: build = %d (%s)\n",      __func__, LLAMA_BUILD_NUMBER, LLAMA_COMMIT);
-    LOG_TEE("%s: built with %s for %s\n", __func__, LLAMA_COMPILER, LLAMA_BUILD_TARGET);
+    print_build_info();
 
-    if (params.seed == LLAMA_DEFAULT_SEED) {
-        params.seed = time(NULL);
-    }
-
-    LOG_TEE("%s: seed  = %u\n", __func__, params.seed);
-
-    std::mt19937 rng(params.seed);
+    LOG_TEE("%s: seed = %u\n", __func__, params.sparams.seed);
 
     LOG("%s: llama backend init\n", __func__);
     llama_backend_init();
     llama_numa_init(params.numa);
 
-    llama_model * model;
-    llama_context * ctx;
-    llama_context * ctx_guidance = NULL;
+    llama_model * model = nullptr;
+    llama_context * ctx = nullptr;
+    gpt_sampler * smpl = nullptr;
+
     std::vector<llama_chat_msg> chat_msgs;
+
     g_model = &model;
     g_ctx = &ctx;
+    g_smpl = &smpl;
 
     // load the model and apply lora adapter, if any
     LOG("%s: load the model and apply lora adapter, if any\n", __func__);
@@ -211,10 +207,6 @@ int main(int argc, char ** argv) {
 
     model = llama_init.model;
     ctx = llama_init.context;
-    if (sparams.cfg_scale > 1.f) {
-        struct llama_context_params lparams = llama_context_params_from_gpt_params(params);
-        ctx_guidance = llama_new_context_with_model(model, lparams);
-    }
 
     if (model == NULL) {
         LOG_TEE("%s: error: unable to load model\n", __func__);
@@ -251,9 +243,6 @@ int main(int argc, char ** argv) {
     }
 
     llama_attach_threadpool(ctx, threadpool, threadpool_batch);
-    if (ctx_guidance) {
-        llama_attach_threadpool(ctx_guidance, threadpool, threadpool_batch);
-    }
 
     const int n_ctx_train = llama_n_ctx_train(model);
     const int n_ctx = llama_n_ctx(ctx);
@@ -337,24 +326,6 @@ int main(int argc, char ** argv) {
     }
 
     // Tokenize negative prompt
-    std::vector<llama_token> guidance_inp;
-    int guidance_offset = 0;
-    int original_prompt_len = 0;
-    if (ctx_guidance) {
-        LOG("cfg_negative_prompt: \"%s\"\n", log_tostr(sparams.cfg_negative_prompt));
-
-        guidance_inp = ::llama_tokenize(ctx_guidance, sparams.cfg_negative_prompt, true, true);
-        LOG("guidance_inp tokenized: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_guidance, guidance_inp).c_str());
-
-        std::vector<llama_token> original_inp = ::llama_tokenize(ctx, params.prompt, true, true);
-        LOG("original_inp tokenized: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, original_inp).c_str());
-
-        original_prompt_len = original_inp.size();
-        guidance_offset = (int)guidance_inp.size() - original_prompt_len;
-        LOG("original_prompt_len: %s", log_tostr(original_prompt_len));
-        LOG("guidance_offset:     %s", log_tostr(guidance_offset));
-    }
-
     if ((int) embd_inp.size() > n_ctx - 4) {
         LOG_TEE("%s: error: prompt is too long (%d tokens, max %d)\n", __func__, (int) embd_inp.size(), n_ctx - 4);
         return 1;
@@ -421,15 +392,6 @@ int main(int argc, char ** argv) {
             LOG_TEE("%6d -> '%s'\n", embd_inp[i], llama_token_to_piece(ctx, embd_inp[i]).c_str());
         }
 
-        if (ctx_guidance) {
-            LOG_TEE("\n");
-            LOG_TEE("%s: negative prompt: '%s'\n", __func__, sparams.cfg_negative_prompt.c_str());
-            LOG_TEE("%s: number of tokens in negative prompt = %zu\n", __func__, guidance_inp.size());
-            for (int i = 0; i < (int) guidance_inp.size(); i++) {
-                LOG_TEE("%6d -> '%s'\n", guidance_inp[i], llama_token_to_piece(ctx, guidance_inp[i]).c_str());
-            }
-        }
-
         if (params.n_keep > add_bos) {
             LOG_TEE("%s: static prompt based on n_keep: '", __func__);
             for (int i = 0; i < params.n_keep; i++) {
@@ -495,8 +457,15 @@ int main(int argc, char ** argv) {
             }
         }
     }
-    LOG_TEE("sampling: \n%s\n", llama_sampling_print(sparams).c_str());
-    LOG_TEE("sampling order: \n%s\n", llama_sampling_order_print(sparams).c_str());
+
+    smpl = gpt_sampler_init(model, sparams);
+    if (!smpl) {
+        fprintf(stderr, "%s: failed to initialize sampling subsystem\n", __func__);
+        exit(1);
+    }
+
+    LOG_TEE("sampling params: \n%s\n", sparams.print().c_str());
+    LOG_TEE(" sampler constr: \n%s\n", gpt_sampler_print(smpl).c_str());
     LOG_TEE("generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep);
 
     // group-attention state
@@ -543,7 +512,6 @@ int main(int argc, char ** argv) {
     int n_remain           = params.n_predict;
     int n_consumed         = 0;
     int n_session_consumed = 0;
-    int n_past_guidance    = 0;
 
     std::vector<int>   input_tokens;  g_input_tokens  = &input_tokens;
     std::vector<int>   output_tokens; g_output_tokens = &output_tokens;
@@ -555,7 +523,6 @@ int main(int argc, char ** argv) {
     display = params.display_prompt;
 
     std::vector<llama_token> embd;
-    std::vector<llama_token> embd_guidance;
 
     // tokenized antiprompts
     std::vector<std::vector<llama_token>> antiprompt_ids;
@@ -565,12 +532,6 @@ int main(int argc, char ** argv) {
         antiprompt_ids.emplace_back(::llama_tokenize(ctx, antiprompt, false, true));
     }
 
-    struct llama_sampling_context * ctx_sampling = llama_sampling_init(sparams);
-    if (!ctx_sampling) {
-        fprintf(stderr, "%s: failed to initialize sampling subsystem\n", __func__);
-        exit(1);
-    }
-
     if (llama_model_has_encoder(model)) {
         int enc_input_size = embd_inp.size();
         llama_token * enc_input_buf = embd_inp.data();
@@ -612,7 +573,7 @@ int main(int argc, char ** argv) {
                 // if we run out of context:
                 // - take the n_keep first tokens from the original prompt (via n_past)
                 // - take half of the last (n_ctx - n_keep) tokens and recompute the logits in batches
-                if (n_past + (int) embd.size() + std::max<int>(0, guidance_offset) >= n_ctx) {
+                if (n_past + (int) embd.size() >= n_ctx) {
                     if (params.n_predict == -2) {
                         LOG_TEE("\n\n%s: context full and n_predict == -%d => stopping\n", __func__, params.n_predict);
                         break;
@@ -629,11 +590,7 @@ int main(int argc, char ** argv) {
 
                     n_past -= n_discard;
 
-                    if (ctx_guidance) {
-                        n_past_guidance -= n_discard;
-                    }
-
-                    LOG("after swap: n_past = %d, n_past_guidance = %d\n", n_past, n_past_guidance);
+                    LOG("after swap: n_past = %d\n", n_past);
 
                     LOG("embd: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd).c_str());
 
@@ -686,46 +643,6 @@ int main(int argc, char ** argv) {
                 }
             }
 
-            // evaluate tokens in batches
-            // embd is typically prepared beforehand to fit within a batch, but not always
-            if (ctx_guidance) {
-                int input_size = 0;
-                llama_token * input_buf = NULL;
-
-                if (n_past_guidance < (int) guidance_inp.size()) {
-                    // Guidance context should have the same data with these modifications:
-                    //
-                    // * Replace the initial prompt
-                    // * Shift everything by guidance_offset
-                    embd_guidance = guidance_inp;
-                    if (embd.begin() + original_prompt_len < embd.end()) {
-                        embd_guidance.insert(
-                            embd_guidance.end(),
-                            embd.begin() + original_prompt_len,
-                            embd.end()
-                        );
-                    }
-
-                    input_buf  = embd_guidance.data();
-                    input_size = embd_guidance.size();
-
-                    LOG("guidance context: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_guidance).c_str());
-                } else {
-                    input_buf  = embd.data();
-                    input_size = embd.size();
-                }
-
-                for (int i = 0; i < input_size; i += params.n_batch) {
-                    int n_eval = std::min(input_size - i, params.n_batch);
-                    if (llama_decode(ctx_guidance, llama_batch_get_one(input_buf + i, n_eval, n_past_guidance, 0))) {
-                        LOG_TEE("%s : failed to eval\n", __func__);
-                        return 1;
-                    }
-
-                    n_past_guidance += n_eval;
-                }
-            }
-
             for (int i = 0; i < (int) embd.size(); i += params.n_batch) {
                 int n_eval = (int) embd.size() - i;
                 if (n_eval > params.n_batch) {
@@ -755,7 +672,6 @@ int main(int argc, char ** argv) {
         }
 
         embd.clear();
-        embd_guidance.clear();
 
         if ((int) embd_inp.size() <= n_consumed && !is_interacting) {
             // optionally save the session on first sample (for faster prompt loading next time)
@@ -766,11 +682,11 @@ int main(int argc, char ** argv) {
                 LOG("saved session to %s\n", path_session.c_str());
             }
 
-            const llama_token id = llama_sampling_sample(ctx_sampling, ctx, ctx_guidance);
+            const llama_token id = gpt_sampler_sample(smpl, ctx, -1);
 
-            llama_sampling_accept(ctx_sampling, ctx, id, /* apply_grammar= */ true);
+            gpt_sampler_accept(smpl, id, /* apply_grammar= */ true);
 
-            LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, ctx_sampling->prev).c_str());
+            // LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, smpl->prev.to_vector()).c_str());
 
             embd.push_back(id);
 
@@ -789,7 +705,7 @@ int main(int argc, char ** argv) {
 
                 // push the prompt in the sampling context in order to apply repetition penalties later
                 // for the prompt, we don't apply grammar rules
-                llama_sampling_accept(ctx_sampling, ctx, embd_inp[n_consumed], /* apply_grammar= */ false);
+                gpt_sampler_accept(smpl, embd_inp[n_consumed], /* apply_grammar= */ false);
 
                 ++n_consumed;
                 if ((int) embd.size() >= params.n_batch) {
@@ -832,7 +748,7 @@ int main(int argc, char ** argv) {
             // check for reverse prompt in the last n_prev tokens
             if (!params.antiprompt.empty()) {
                 const int n_prev = 32;
-                const std::string last_output = llama_sampling_prev_str(ctx_sampling, ctx, n_prev);
+                const std::string last_output = gpt_sampler_prev_str(smpl, ctx, n_prev);
 
                 is_antiprompt = false;
                 // Check if each of the reverse prompts appears at the end of the output.
@@ -854,7 +770,7 @@ int main(int argc, char ** argv) {
                 }
 
                 // check for reverse prompt using special tokens
-                llama_token last_token = llama_sampling_last(ctx_sampling);
+                llama_token last_token = gpt_sampler_last(smpl);
                 for (std::vector<llama_token> ids : antiprompt_ids) {
                     if (ids.size() == 1 && last_token == ids[0]) {
                         if (params.interactive) {
@@ -871,7 +787,7 @@ int main(int argc, char ** argv) {
             }
 
             // deal with end of generation tokens in interactive mode
-            if (llama_token_is_eog(model, llama_sampling_last(ctx_sampling))) {
+            if (llama_token_is_eog(model, gpt_sampler_last(smpl))) {
                 LOG("found an EOG token\n");
 
                 if (params.interactive) {
@@ -892,7 +808,7 @@ int main(int argc, char ** argv) {
 
             // if current token is not EOG, we add it to current assistant message
             if (params.conversation) {
-                auto id = llama_sampling_last(ctx_sampling);
+                const auto id = gpt_sampler_last(smpl);
                 assistant_ss << llama_token_to_piece(ctx, id, false);
             }
 
@@ -988,7 +904,7 @@ int main(int argc, char ** argv) {
 
             if (n_past > 0) {
                 if (is_interacting) {
-                    llama_sampling_reset(ctx_sampling);
+                    gpt_sampler_reset(smpl);
                 }
                 is_interacting = false;
             }
@@ -1013,14 +929,15 @@ int main(int argc, char ** argv) {
         llama_state_save_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.size());
     }
 
-    llama_print_timings(ctx);
+    LOG_TEE("\n");
+    gpt_perf_print(ctx, smpl);
     write_logfile(ctx, params, model, input_tokens, output_ss.str(), output_tokens);
 
-    if (ctx_guidance) { llama_free(ctx_guidance); }
+    gpt_sampler_free(smpl);
+
     llama_free(ctx);
     llama_free_model(model);
 
-    llama_sampling_free(ctx_sampling);
     llama_backend_free();
 
     ggml_threadpool_free(threadpool);
index 621a1c959062265912993c6065e2e67e785da378..c331c0f28dc7eacdba0979c6077ad54b80188f0c 100644 (file)
@@ -50,8 +50,8 @@ static std::vector<std::string> k_prompts = {
 
 struct client {
     ~client() {
-        if (ctx_sampling) {
-            llama_sampling_free(ctx_sampling);
+        if (smpl) {
+            gpt_sampler_free(smpl);
         }
     }
 
@@ -72,7 +72,7 @@ struct client {
     std::string prompt;
     std::string response;
 
-    struct llama_sampling_context * ctx_sampling = nullptr;
+    struct gpt_sampler * smpl = nullptr;
 };
 
 static void print_date_time() {
@@ -161,7 +161,7 @@ int main(int argc, char ** argv) {
     for (size_t i = 0; i < clients.size(); ++i) {
         auto & client = clients[i];
         client.id = i;
-        client.ctx_sampling = llama_sampling_init(params.sparams);
+        client.smpl = gpt_sampler_init(model, params.sparams);
     }
 
     std::vector<llama_token> tokens_system;
@@ -253,7 +253,7 @@ int main(int argc, char ** argv) {
                     client.prompt   = client.input + "\nAssistant:";
                     client.response = "";
 
-                    llama_sampling_reset(client.ctx_sampling);
+                    gpt_sampler_reset(client.smpl);
 
                     // do not prepend BOS because we have a system prompt!
                     std::vector<llama_token> tokens_prompt;
@@ -341,9 +341,9 @@ int main(int argc, char ** argv) {
                 //printf("client %d, seq %d, token %d, pos %d, batch %d\n",
                 //        client.id, client.seq_id, client.sampled, client.n_decoded, client.i_batch);
 
-                const llama_token id = llama_sampling_sample(client.ctx_sampling, ctx, NULL, client.i_batch - i);
+                const llama_token id = gpt_sampler_sample(client.smpl, ctx, client.i_batch - i);
 
-                llama_sampling_accept(client.ctx_sampling, ctx, id, true);
+                gpt_sampler_accept(client.smpl, id, true);
 
                 if (client.n_decoded == 1) {
                     // start measuring generation time after the first token to make sure all concurrent clients
@@ -371,7 +371,7 @@ int main(int argc, char ** argv) {
                     }
 
                     // delete only the generated part of the sequence, i.e. keep the system prompt in the cache
-                    llama_kv_cache_seq_rm(ctx, client.id + 1, -1, -1);
+                    llama_kv_cache_seq_rm(ctx,    client.id + 1, -1, -1);
                     llama_kv_cache_seq_cp(ctx, 0, client.id + 1, -1, -1);
 
                     const auto t_main_end = ggml_time_us();
@@ -413,7 +413,8 @@ int main(int argc, char ** argv) {
 
     LOG_TEE("\n");
 
-    llama_print_timings(ctx);
+    // TODO: print sampling/grammar timings for all clients
+    llama_perf_print(ctx, LLAMA_PERF_TYPE_CONTEXT);
 
     llama_batch_free(batch);
 
index d03215cd1e0a94e4810868e234996974ab2753f7..ff8d0302f8f0a902adf57ece4497625456a7bb87 100644 (file)
@@ -26,8 +26,6 @@ int main(int argc, char ** argv) {
         return 1;
     }
 
-    srand(params.seed == LLAMA_DEFAULT_SEED ? time(NULL) : params.seed);
-
     int n_junk = params.n_junk;
     int n_keep = params.n_keep;
     int n_grp  = params.grp_attn_n;
@@ -80,12 +78,17 @@ int main(int argc, char ** argv) {
     GGML_ASSERT(ctx_params.n_batch % n_grp == 0 && "n_batch must be divisible by n_grp");
 
     llama_context * ctx = llama_new_context_with_model(model, ctx_params);
-
     if (ctx == NULL) {
         fprintf(stderr , "%s: error: failed to create the llama_context\n" , __func__);
         return 1;
     }
 
+    auto sparams = llama_sampler_chain_default_params();
+
+    llama_sampler * smpl = llama_sampler_chain_init(sparams);
+
+    llama_sampler_chain_add(smpl, llama_sampler_init_greedy());
+
     // tokenize the prompt
     std::vector<llama_token> tokens_list;
     tokens_list = ::llama_tokenize(ctx, params.prompt, true);
@@ -217,20 +220,9 @@ int main(int argc, char ** argv) {
     while (n_cur <= n_len) {
         // sample the next token
         {
-            auto   n_vocab = llama_n_vocab(model);
-            auto * logits  = llama_get_logits_ith(ctx, batch.n_tokens - 1);
+            const llama_token new_token_id = llama_sampler_sample(smpl, ctx, batch.n_tokens - 1);
 
-            std::vector<llama_token_data> candidates;
-            candidates.reserve(n_vocab);
-
-            for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
-                candidates.emplace_back(llama_token_data{ token_id, logits[token_id], 0.0f });
-            }
-
-            llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
-
-            // sample the most likely token
-            const llama_token new_token_id = llama_sample_token_greedy(ctx, &candidates_p);
+            llama_sampler_accept(smpl, new_token_id);
 
             // is it an end of generation?
             if (llama_token_is_eog(model, new_token_id) || n_cur == n_len) {
@@ -267,10 +259,13 @@ int main(int argc, char ** argv) {
     LOG_TEE("%s: decoded %d tokens in %.2f s, speed: %.2f t/s\n",
             __func__, n_decode, (t_main_end - t_main_start) / 1000000.0f, n_decode / ((t_main_end - t_main_start) / 1000000.0f));
 
-    llama_print_timings(ctx);
+    LOG_TEE("\n");
+    llama_perf_print(ctx, LLAMA_PERF_TYPE_CONTEXT);
 
     fprintf(stderr, "\n");
 
+    llama_sampler_free(smpl);
+
     llama_batch_free(batch);
 
     llama_free(ctx);
index 484dd589109c7cbc736a28bcea52be686547ce1a..2ca43f1256765407b407e8c1e111ef300f3c3fb0 100644 (file)
@@ -76,7 +76,7 @@ static void write_logfile(
     fprintf(logfile, "ppl_value: %f\n", results.ppl_value);
     yaml_dump_vector_float(logfile, "probs", results.probs);
 
-    llama_dump_timing_info_yaml(logfile, ctx);
+    llama_perf_dump_yaml(logfile, ctx);
     fclose(logfile);
 }
 
@@ -2007,13 +2007,7 @@ int main(int argc, char ** argv) {
 
     print_build_info();
 
-    if (params.seed == LLAMA_DEFAULT_SEED) {
-        params.seed = time(NULL);
-    }
-
-    fprintf(stderr, "%s: seed  = %u\n", __func__, params.seed);
-
-    std::mt19937 rng(params.seed);
+    LOG_TEE("%s: seed = %u\n", __func__, params.sparams.seed);
 
     llama_backend_init();
     llama_numa_init(params.numa);
@@ -2054,7 +2048,8 @@ int main(int argc, char ** argv) {
         results = perplexity(ctx, params, n_ctx);
     }
 
-    llama_print_timings(ctx);
+    LOG_TEE("\n");
+    llama_perf_print(ctx, LLAMA_PERF_TYPE_CONTEXT);
     write_logfile(ctx, params, model, results);
 
     llama_free(ctx);
index 68cf8d3595e878eade9304388aea81eedf1da39d..498cbbe3ce1cdacedd603a247c2385d39aa8ec21 100644 (file)
@@ -1,7 +1,7 @@
-#define LLAMA_API_INTERNAL
 #include "common.h"
 #include "ggml.h"
 #include "llama.h"
+#include "llama-impl.h"
 
 #include <algorithm>
 #include <cassert>
@@ -319,8 +319,7 @@ int main(int argc, char ** argv) {
         }
 
         auto cparams = llama_context_default_params();
-        cparams.n_ctx      = 256;
-        cparams.seed       = 1;
+        cparams.n_ctx = 256;
 
         ctx = llama_new_context_with_model(model, cparams);
 
index aab9d81058af93b423d228d953d6353313a2c041..7eb94765041a2c548dfd1b04db7ed50284c45df4 100644 (file)
@@ -293,9 +293,11 @@ int main(int argc, char ** argv) {
         }
     }
 
+    LOG_TEE("\n");
+    llama_perf_print(ctx, LLAMA_PERF_TYPE_CONTEXT);
+
     // clean up
     llama_batch_free(query_batch);
-    llama_print_timings(ctx);
     llama_free(ctx);
     llama_free_model(model);
     llama_backend_free();
index 3ea7c790d2bf782138124d0ea424f109b64e08b2..133a010e4757aa731d6d8f3f08c7df8a4f46ee5f 100644 (file)
@@ -3,12 +3,12 @@
 
 #include <vector>
 #include <cstdio>
-#include <chrono>
 
 int main(int argc, char ** argv) {
     gpt_params params;
 
     params.prompt = "The quick brown fox";
+    params.sparams.seed = 1234;
 
     if (!gpt_params_parse(argc, argv, params)) {
         gpt_params_print_usage(argc, argv, params);
@@ -38,6 +38,13 @@ int main(int argc, char ** argv) {
         return 1;
     }
 
+    auto sparams = llama_sampler_chain_default_params();
+
+    llama_sampler * smpl = llama_sampler_chain_init(sparams);
+
+    llama_sampler_chain_add(smpl, llama_sampler_init_softmax());
+    llama_sampler_chain_add(smpl, llama_sampler_init_dist(params.sparams.seed));
+
     // tokenize prompt
     auto tokens = llama_tokenize(ctx, params.prompt, true);
 
@@ -64,18 +71,11 @@ int main(int argc, char ** argv) {
     printf("\nfirst run: %s", params.prompt.c_str());
 
     for (auto i = 0; i < params.n_predict; i++) {
-        auto * logits = llama_get_logits(ctx);
-        auto n_vocab = llama_n_vocab(model);
-
-        std::vector<llama_token_data> candidates;
-        candidates.reserve(n_vocab);
-        for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
-            candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
-        }
-        llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
-        auto next_token = llama_sample_token(ctx, &candidates_p);
+        auto next_token     = llama_sampler_sample(smpl, ctx, -1);
         auto next_token_str = llama_token_to_piece(ctx, next_token);
 
+        llama_sampler_accept(smpl, next_token);
+
         printf("%s", next_token_str.c_str());
         result0 += next_token_str;
 
@@ -96,6 +96,11 @@ int main(int argc, char ** argv) {
     // make new context
     auto * ctx2 = llama_new_context_with_model(model, llama_context_params_from_gpt_params(params));
 
+    llama_sampler * smpl2 = llama_sampler_chain_init(sparams);
+
+    llama_sampler_chain_add(smpl2, llama_sampler_init_softmax());
+    llama_sampler_chain_add(smpl2, llama_sampler_init_dist(params.sparams.seed));
+
     printf("\nsecond run: %s", params.prompt.c_str());
 
     // load state (rng, logits, embedding and kv_cache) from file
@@ -124,17 +129,11 @@ int main(int argc, char ** argv) {
 
     // second run
     for (auto i = 0; i < params.n_predict; i++) {
-        auto * logits = llama_get_logits(ctx2);
-        auto n_vocab = llama_n_vocab(model);
-        std::vector<llama_token_data> candidates;
-        candidates.reserve(n_vocab);
-        for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
-            candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
-        }
-        llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
-        auto next_token = llama_sample_token(ctx2, &candidates_p);
+        auto next_token     = llama_sampler_sample(smpl2, ctx2, -1);
         auto next_token_str = llama_token_to_piece(ctx2, next_token);
 
+        llama_sampler_accept(smpl2, next_token);
+
         printf("%s", next_token_str.c_str());
         result1 += next_token_str;
 
@@ -157,7 +156,12 @@ int main(int argc, char ** argv) {
     }
 
     // make new context
-    auto* ctx3 = llama_new_context_with_model(model, llama_context_params_from_gpt_params(params));
+    auto * ctx3 = llama_new_context_with_model(model, llama_context_params_from_gpt_params(params));
+
+    llama_sampler * smpl3 = llama_sampler_chain_init(sparams);
+
+    llama_sampler_chain_add(smpl3, llama_sampler_init_softmax());
+    llama_sampler_chain_add(smpl3, llama_sampler_init_dist(params.sparams.seed));
 
     printf("\nsingle seq run: %s", params.prompt.c_str());
 
@@ -215,17 +219,11 @@ int main(int argc, char ** argv) {
 
     // third run with seq 1 instead of 0
     for (auto i = 0; i < params.n_predict; i++) {
-        auto * logits = llama_get_logits(ctx3);
-        auto n_vocab = llama_n_vocab(model);
-        std::vector<llama_token_data> candidates;
-        candidates.reserve(n_vocab);
-        for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
-            candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
-        }
-        llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
-        auto next_token = llama_sample_token(ctx3, &candidates_p);
+        auto next_token     = llama_sampler_sample(smpl3, ctx3, -1);
         auto next_token_str = llama_token_to_piece(ctx3, next_token);
 
+        llama_sampler_accept(smpl3, next_token);
+
         printf("%s", next_token_str.c_str());
         result2 += next_token_str;
 
@@ -240,6 +238,10 @@ int main(int argc, char ** argv) {
 
     printf("\n");
 
+    llama_sampler_free(smpl);
+    llama_sampler_free(smpl2);
+    llama_sampler_free(smpl3);
+
     llama_free(ctx3);
     llama_free_model(model);
 
index 805e05b4a51142fa81b922c87274d59cb09efdce..37024dea0055c24529741a9a93d426d4ab8af246 100644 (file)
@@ -470,8 +470,6 @@ node index.js
 
     `frequency_penalty`: Repeat alpha frequency penalty. Default: `0.0`, which is disabled.
 
-    `penalty_prompt`: This will replace the `prompt` for the purpose of the penalty evaluation. Can be either `null`, a string or an array of numbers representing tokens. Default: `null`, which is to use the original `prompt`.
-
     `mirostat`: Enable Mirostat sampling, controlling perplexity during text generation. Default: `0`, where `0` is disabled, `1` is Mirostat, and `2` is Mirostat 2.0.
 
     `mirostat_tau`: Set the Mirostat target entropy, parameter tau. Default: `5.0`
@@ -724,7 +722,6 @@ Example:
             "stopping_word": ""
         },
         "penalize_nl": true,
-        "penalty_prompt_tokens": [],
         "presence_penalty": 0.0,
         "prompt": "Say hello to llama.cpp",
         "repeat_last_n": 64,
@@ -748,8 +745,7 @@ Example:
         "tfs_z": 1.0,
         "top_k": 40,
         "top_p": 0.949999988079071,
-        "typical_p": 1.0,
-        "use_penalty_prompt_tokens": false
+        "typical_p": 1.0
     }
 ]
 ```
index cc65c57ab723c9dc0abaddb6506051e38d7dff01..f45b59983f05b55a69b29158acf64eca277cffad 100644 (file)
@@ -3,7 +3,6 @@
 #include "common.h"
 #include "json-schema-to-grammar.h"
 #include "llama.h"
-#include "grammar-parser.h"
 
 // Change JSON_ASSERT from assert() to GGML_ASSERT:
 #define JSON_ASSERT GGML_ASSERT
@@ -169,11 +168,13 @@ struct server_slot {
     std::string stopping_word;
 
     // sampling
-    llama_token sampled;
-    struct llama_sampling_params sparams;
-    llama_sampling_context * ctx_sampling = nullptr;
     json json_schema;
 
+    struct gpt_sampler_params sparams;
+    struct gpt_sampler * smpl = nullptr;
+
+    llama_token sampled;
+
     int32_t ga_i = 0;   // group-attention state
     int32_t ga_n = 1;   // group-attention factor
     int32_t ga_w = 512; // group-attention width
@@ -651,8 +652,8 @@ struct server_context {
 
         // Clear any sampling context
         for (server_slot & slot : slots) {
-            if (slot.ctx_sampling != nullptr) {
-                llama_sampling_free(slot.ctx_sampling);
+            if (slot.smpl != nullptr) {
+                gpt_sampler_free(slot.smpl);
             }
         }
 
@@ -883,8 +884,8 @@ struct server_context {
     bool launch_slot_with_task(server_slot & slot, const server_task & task) {
         slot_params default_params;
         // Sampling parameter defaults are loaded from the global server context (but individual requests can still override them)
-        llama_sampling_params default_sparams = params.sparams;
-        auto & data = task.data;
+        auto default_sparams = params.sparams;
+        const auto & data = task.data;
 
         if (data.count("__oaicompat") != 0) {
             slot.oaicompat = true;
@@ -901,7 +902,7 @@ struct server_context {
         slot.sparams.top_p             = json_value(data, "top_p",             default_sparams.top_p);
         slot.sparams.min_p             = json_value(data, "min_p",             default_sparams.min_p);
         slot.sparams.tfs_z             = json_value(data, "tfs_z",             default_sparams.tfs_z);
-        slot.sparams.typical_p         = json_value(data, "typical_p",         default_sparams.typical_p);
+        slot.sparams.typ_p             = json_value(data, "typical_p",         default_sparams.typ_p);
         slot.sparams.temp              = json_value(data, "temperature",       default_sparams.temp);
         slot.sparams.dynatemp_range    = json_value(data, "dynatemp_range",    default_sparams.dynatemp_range);
         slot.sparams.dynatemp_exponent = json_value(data, "dynatemp_exponent", default_sparams.dynatemp_exponent);
@@ -923,7 +924,8 @@ struct server_context {
         if (data.contains("json_schema") && !data.at("json_schema").is_null() && data.contains("grammar") && !data.at("grammar").is_null()) {
             send_error(task, "Either \"json_schema\" or \"grammar\" can be specified, but not both", ERROR_TYPE_INVALID_REQUEST);
             return false;
-        } else if (data.contains("json_schema") && !data.contains("grammar")) {
+        }
+        if (data.contains("json_schema") && !data.contains("grammar")) {
             try {
                 auto schema                = json_value(data, "json_schema", json::object());
                 slot.sparams.grammar       = json_schema_to_grammar(schema);
@@ -973,56 +975,11 @@ struct server_context {
             }
         }
 
-        // penalize user-provided tokens
-        {
-            slot.sparams.penalty_prompt_tokens.clear();
-            slot.sparams.use_penalty_prompt_tokens = false;
-
-            const auto & penalty_prompt = data.find("penalty_prompt");
-
-            if (penalty_prompt != data.end()) {
-                if (penalty_prompt->is_string()) {
-                    const auto penalty_prompt_string = penalty_prompt->get<std::string>();
-                    slot.sparams.penalty_prompt_tokens = llama_tokenize(model, penalty_prompt_string, false);
-
-                    if (slot.params.n_predict > 0) {
-                        slot.sparams.penalty_prompt_tokens.reserve(slot.sparams.penalty_prompt_tokens.size() + slot.params.n_predict);
-                    }
-                    slot.sparams.use_penalty_prompt_tokens = true;
-
-                    LOG_VERBOSE("penalty_prompt_tokens", {
-                        {"id_slot", slot.id},
-                        {"tokens",  slot.sparams.penalty_prompt_tokens},
-                    });
-                }
-                else if (penalty_prompt->is_array()) {
-                    const auto n_tokens = penalty_prompt->size();
-                    slot.sparams.penalty_prompt_tokens.reserve(n_tokens + std::max(0, slot.params.n_predict));
-
-                    const int n_vocab = llama_n_vocab(model);
-                    for (const auto & penalty_token : *penalty_prompt) {
-                        if (penalty_token.is_number_integer()) {
-                            const auto tok = penalty_token.get<llama_token>();
-                            if (tok >= 0 && tok < n_vocab) {
-                                slot.sparams.penalty_prompt_tokens.push_back(tok);
-                            }
-                        }
-                    }
-                    slot.sparams.use_penalty_prompt_tokens = true;
-
-                    LOG_VERBOSE("penalty_prompt_tokens", {
-                        {"id_slot", slot.id},
-                        {"tokens",  slot.sparams.penalty_prompt_tokens},
-                    });
-                }
-            }
-        }
-
         {
             slot.sparams.logit_bias.clear();
 
             if (json_value(data, "ignore_eos", false) && has_eos_token) {
-                slot.sparams.logit_bias[llama_token_eos(model)] = -INFINITY;
+                slot.sparams.logit_bias.push_back({llama_token_eos(model), -INFINITY});
             }
 
             const auto & logit_bias = data.find("logit_bias");
@@ -1043,12 +1000,12 @@ struct server_context {
                         if (el[0].is_number_integer()) {
                             llama_token tok = el[0].get<llama_token>();
                             if (tok >= 0 && tok < n_vocab) {
-                                slot.sparams.logit_bias[tok] = bias;
+                                slot.sparams.logit_bias.push_back({tok, bias});
                             }
                         } else if (el[0].is_string()) {
                             auto toks = llama_tokenize(model, el[0].get<std::string>(), false);
                             for (auto tok : toks) {
-                                slot.sparams.logit_bias[tok] = bias;
+                                slot.sparams.logit_bias.push_back({tok, bias});
                             }
                         }
                     }
@@ -1070,26 +1027,27 @@ struct server_context {
         }
 
         {
-            const auto & samplers_sequence = data.find("samplers");
-            if (samplers_sequence != data.end() && samplers_sequence->is_array()) {
+            const auto & samplers = data.find("samplers");
+            if (samplers != data.end() && samplers->is_array()) {
                 std::vector<std::string> sampler_names;
-                for (const auto & sampler_name : *samplers_sequence) {
-                    if (sampler_name.is_string()) {
-                        sampler_names.emplace_back(sampler_name);
+                for (const auto & name : *samplers) {
+                    if (name.is_string()) {
+                        sampler_names.emplace_back(name);
                     }
                 }
-                slot.sparams.samplers_sequence = llama_sampling_types_from_names(sampler_names, false);
+                slot.sparams.samplers = gpt_sampler_types_from_names(sampler_names, false);
             } else {
-                slot.sparams.samplers_sequence = default_sparams.samplers_sequence;
+                slot.sparams.samplers = default_sparams.samplers;
             }
         }
 
         {
-            if (slot.ctx_sampling != nullptr) {
-                llama_sampling_free(slot.ctx_sampling);
+            if (slot.smpl != nullptr) {
+                gpt_sampler_free(slot.smpl);
             }
-            slot.ctx_sampling = llama_sampling_init(slot.sparams);
-            if (slot.ctx_sampling == nullptr) {
+
+            slot.smpl = gpt_sampler_init(model, slot.sparams);
+            if (slot.smpl == nullptr) {
                 // for now, the only error that may happen here is invalid grammar
                 send_error(task, "Failed to parse grammar", ERROR_TYPE_INVALID_REQUEST);
                 return false;
@@ -1178,11 +1136,6 @@ struct server_context {
         slot.generated_text += token_str;
         slot.has_next_token = true;
 
-        if (slot.ctx_sampling->params.use_penalty_prompt_tokens && result.tok != -1) {
-            // we can change penalty_prompt_tokens because it is always created from scratch each request
-            slot.ctx_sampling->params.penalty_prompt_tokens.push_back(result.tok);
-        }
-
         // check if there is incomplete UTF-8 character at the end
         bool incomplete = false;
         for (unsigned i = 1; i < 5 && i <= slot.generated_text.size(); ++i) {
@@ -1300,13 +1253,10 @@ struct server_context {
     }
 
     json get_formated_generation(const server_slot & slot) const {
-        const auto eos_bias   =             slot.sparams.logit_bias.find(llama_token_eos(model));
-        const bool ignore_eos = eos_bias != slot.sparams.logit_bias.end() && eos_bias->second < 0.0f && std::isinf(eos_bias->second);
-
-        std::vector<std::string> samplers_sequence;
-        samplers_sequence.reserve(slot.sparams.samplers_sequence.size());
-        for (const auto & sampler_type : slot.sparams.samplers_sequence) {
-            samplers_sequence.emplace_back(llama_sampling_type_to_str(sampler_type));
+        std::vector<std::string> samplers;
+        samplers.reserve(slot.sparams.samplers.size());
+        for (const auto & sampler : slot.sparams.samplers) {
+            samplers.emplace_back(gpt_sampler_type_to_str(sampler));
         }
 
         return json {
@@ -1321,13 +1271,11 @@ struct server_context {
             {"top_p",                     slot.sparams.top_p},
             {"min_p",                     slot.sparams.min_p},
             {"tfs_z",                     slot.sparams.tfs_z},
-            {"typical_p",                 slot.sparams.typical_p},
+            {"typical_p",                 slot.sparams.typ_p},
             {"repeat_last_n",             slot.sparams.penalty_last_n},
             {"repeat_penalty",            slot.sparams.penalty_repeat},
             {"presence_penalty",          slot.sparams.penalty_present},
             {"frequency_penalty",         slot.sparams.penalty_freq},
-            {"penalty_prompt_tokens",     slot.sparams.penalty_prompt_tokens},
-            {"use_penalty_prompt_tokens", slot.sparams.use_penalty_prompt_tokens},
             {"mirostat",                  slot.sparams.mirostat},
             {"mirostat_tau",              slot.sparams.mirostat_tau},
             {"mirostat_eta",              slot.sparams.mirostat_eta},
@@ -1336,13 +1284,13 @@ struct server_context {
             {"max_tokens",                slot.params.n_predict}, // User configured n_predict
             {"n_keep",                    slot.params.n_keep},
             {"n_discard",                 slot.params.n_discard},
-            {"ignore_eos",                ignore_eos},
+            {"ignore_eos",                slot.sparams.ignore_eos},
             {"stream",                    slot.params.stream},
-            {"logit_bias",                slot.sparams.logit_bias},
+          //{"logit_bias",                slot.sparams.logit_bias},
             {"n_probs",                   slot.sparams.n_probs},
             {"min_keep",                  slot.sparams.min_keep},
             {"grammar",                   slot.sparams.grammar},
-            {"samplers",                  samplers_sequence}
+            {"samplers",                  samplers},
         };
     }
 
@@ -2136,7 +2084,7 @@ struct server_context {
                                 GGML_ASSERT(slot.n_prompt_tokens < slot.n_ctx);
                             }
 
-                            llama_sampling_reset(slot.ctx_sampling);
+                            gpt_sampler_reset(slot.smpl);
 
                             if (!slot.params.cache_prompt) {
                                 slot.n_past_se = 0;
@@ -2149,7 +2097,7 @@ struct server_context {
 
                                 // push the prompt into the sampling context (do not apply grammar)
                                 for (int i = 0; i < slot.n_past; ++i) {
-                                    llama_sampling_accept(slot.ctx_sampling, ctx, slot.cache_tokens[i], false);
+                                    gpt_sampler_accept(slot.smpl, slot.cache_tokens[i], false);
                                 }
                             }
                         }
@@ -2202,7 +2150,7 @@ struct server_context {
                         slot.n_past_se = 0;
                         slot.ga_i = 0;
                         // TODO: is the system prompt ever in the sampling context?
-                        llama_sampling_reset(slot.ctx_sampling);
+                        gpt_sampler_reset(slot.smpl);
                     }
 
                     // remove the non-common part from the cache
@@ -2375,18 +2323,18 @@ struct server_context {
                         slot.release();
                         slot.i_batch = -1;
                         continue; // continue loop of slots
-                    } else {
-                        // prompt evaluated for next-token prediction
-                        slot.state = SLOT_STATE_GENERATING;
                     }
+
+                    // prompt evaluated for next-token prediction
+                    slot.state = SLOT_STATE_GENERATING;
                 } else if (slot.state != SLOT_STATE_GENERATING) {
                     continue; // continue loop of slots
                 }
 
                 completion_token_output result;
-                const llama_token id = llama_sampling_sample(slot.ctx_sampling, ctx, NULL, slot.i_batch - i);
+                const llama_token id = gpt_sampler_sample(slot.smpl, ctx, slot.i_batch - i);
 
-                llama_sampling_accept(slot.ctx_sampling, ctx, id, true);
+                gpt_sampler_accept(slot.smpl, id, true);
 
                 slot.n_decoded += 1;
                 if (slot.n_decoded == 1) {
@@ -2395,34 +2343,15 @@ struct server_context {
                     metrics.on_prompt_eval(slot);
                 }
 
-                llama_token_data_array cur_p = { slot.ctx_sampling->cur.data(), slot.ctx_sampling->cur.size(), false };
                 result.tok = id;
 
-                const size_t n_probs = std::min(cur_p.size, (size_t) slot.sparams.n_probs);
-                if (n_probs > 0) {
-                    const size_t n_valid = slot.ctx_sampling->n_valid;
+                const auto * cur_p = gpt_sampler_get_candidates(slot.smpl);
 
-                    // Make sure at least n_probs top tokens are at the front of the vector:
-                    if (slot.sparams.temp == 0.0f && n_probs > n_valid) {
-                        llama_sample_top_k(ctx, &cur_p, n_probs, 0);
-                    }
-
-                    if (slot.sparams.temp == 0.0f) {
-                        // With greedy sampling the probabilities have possibly not been calculated.
-                        for (size_t i = 0; i < n_probs; ++i) {
-                            result.probs.push_back({
-                                cur_p.data[i].id,
-                                i == 0 ? 1.0f : 0.0f
-                            });
-                        }
-                    } else {
-                        for (size_t i = 0; i < n_probs; ++i) {
-                            result.probs.push_back({
-                                cur_p.data[i].id,
-                                i >= n_valid ? 0.0f : cur_p.data[i].p // Tokens filtered out due to e.g. top_k have 0 probability.
-                            });
-                        }
-                    }
+                for (size_t i = 0; i < (size_t) slot.sparams.n_probs; ++i) {
+                    result.probs.push_back({
+                        cur_p->data[i].id,
+                        i >= cur_p->size ? 0.0f : cur_p->data[i].p,
+                    });
                 }
 
                 if (!process_token(result, slot)) {
index 69a92cf7dc0c01ab0010d18420ea097659392733..8a0ad43ad31b8091503b345f055c710f38685aa0 100644 (file)
@@ -55,6 +55,14 @@ int main(int argc, char ** argv) {
         return 1;
     }
 
+    auto sparams = llama_sampler_chain_default_params();
+
+    sparams.no_perf = false;
+
+    llama_sampler * smpl = llama_sampler_chain_init(sparams);
+
+    llama_sampler_chain_add(smpl, llama_sampler_init_greedy());
+
     // tokenize the prompt
 
     std::vector<llama_token> tokens_list;
@@ -110,20 +118,9 @@ int main(int argc, char ** argv) {
     while (n_cur <= n_predict) {
         // sample the next token
         {
-            auto   n_vocab = llama_n_vocab(model);
-            auto * logits  = llama_get_logits_ith(ctx, batch.n_tokens - 1);
-
-            std::vector<llama_token_data> candidates;
-            candidates.reserve(n_vocab);
+            const llama_token new_token_id = llama_sampler_sample(smpl, ctx, batch.n_tokens - 1);
 
-            for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
-                candidates.emplace_back(llama_token_data{ token_id, logits[token_id], 0.0f });
-            }
-
-            llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
-
-            // sample the most likely token
-            const llama_token new_token_id = llama_sample_token_greedy(ctx, &candidates_p);
+            llama_sampler_accept(smpl, new_token_id);
 
             // is it an end of generation?
             if (llama_token_is_eog(model, new_token_id) || n_cur == n_predict) {
@@ -160,12 +157,14 @@ int main(int argc, char ** argv) {
     LOG_TEE("%s: decoded %d tokens in %.2f s, speed: %.2f t/s\n",
             __func__, n_decode, (t_main_end - t_main_start) / 1000000.0f, n_decode / ((t_main_end - t_main_start) / 1000000.0f));
 
-    llama_print_timings(ctx);
+    LOG_TEE("\n");
+    llama_perf_print(smpl, LLAMA_PERF_TYPE_SAMPLER_CHAIN);
+    llama_perf_print(ctx,  LLAMA_PERF_TYPE_CONTEXT);
 
     fprintf(stderr, "\n");
 
     llama_batch_free(batch);
-
+    llama_sampler_free(smpl);
     llama_free(ctx);
     llama_free_model(model);
 
index 1616edecbbef6d3d8e6f18b5b3b43f8a21eaefb6..55c6bda70e8e1c757051ee8a716b93a3dff2b135 100644 (file)
@@ -21,7 +21,7 @@ struct seq_draft {
     std::vector<llama_token> tokens;
     std::vector<std::vector<llama_token_data>> dists;
 
-    struct llama_sampling_context * ctx_sampling;
+    struct gpt_sampler * smpl = nullptr;
 };
 
 int main(int argc, char ** argv) {
@@ -43,10 +43,7 @@ int main(int argc, char ** argv) {
     // probability threshold for splitting a draft branch (only for n_seq_dft > 1)
     const float p_split  = params.p_split;
 
-    if (params.seed == LLAMA_DEFAULT_SEED) {
-        params.seed = time(NULL);
-    }
-    std::default_random_engine rng(params.seed);
+    std::default_random_engine rng(params.sparams.seed);
     std::uniform_real_distribution<> u_dist;
 
 #ifndef LOG_DISABLE_LOGS
@@ -179,19 +176,17 @@ int main(int argc, char ** argv) {
     // used to determine end of generation
     bool has_eos = false;
 
-    // target model sampling context
-    struct llama_sampling_context * ctx_sampling = llama_sampling_init(params.sparams);
+    // target model sampling context (reuse the llama_context's sampling instance)
+    struct gpt_sampler * smpl = gpt_sampler_init(model_tgt, params.sparams);
+
+    struct llama_sampler * softmax = llama_sampler_init_softmax();
 
     // draft sequence data
     std::vector<seq_draft> drafts(n_seq_dft);
 
-    params.sparams.grammar.clear(); // the draft samplers will copy the target sampler's grammar
-    if (params.sparams.temp == 0) {
-        params.sparams.temp = -1.0f; // force greedy sampling with probs for the draft model
-    }
-
     for (int s = 0; s < n_seq_dft; ++s) {
-        drafts[s].ctx_sampling = llama_sampling_init(params.sparams);
+        // allocate gpt_sampler for each draft sequence
+        drafts[s].smpl = gpt_sampler_init(model_dft, params.sparams);
     }
 
     llama_batch batch_dft = llama_batch_init(params.n_ctx, 0, 1);
@@ -233,12 +228,12 @@ int main(int argc, char ** argv) {
                 bool accept = false;
                 if (params.sparams.temp > 0) {
                     // stochastic verification
+                    gpt_sampler_sample(smpl, ctx_tgt, drafts[s_keep].i_batch_tgt[i_dft], true);
 
-                    llama_token_data_array dist_tgt = llama_sampling_prepare(ctx_sampling, ctx_tgt, NULL, drafts[s_keep].i_batch_tgt[i_dft], true, NULL);
-                    llama_sample_softmax(ctx_tgt, &dist_tgt);
-                    float p_tgt = 0, p_dft = 0;
+                    auto & dist_tgt = *gpt_sampler_get_candidates(smpl);
 
-                    // GGML_ASSERT(dist_tgt.size() == dist_dft.size());
+                    float p_tgt = 0.0f;
+                    float p_dft = 0.0f;
 
                     while (active_seqs.size() > 0) {
                         // randomly select a sequence to verify from active sequences
@@ -257,9 +252,13 @@ int main(int argc, char ** argv) {
                             }
                             continue;
                         }
+
                         LOG("verifying sequence #%d at pos #%d from %d active sequence(s)\n", s, i_dft, (int) active_seqs.size());
                         float r = u_dist(rng);
-                        llama_token_data_array dist_dft = { drafts[s].dists[i_dft].data() , drafts[s].dists[i_dft].size(), true };
+                        llama_token_data_array dist_dft = { drafts[s].dists[i_dft].data() , drafts[s].dists[i_dft].size(), LLAMA_TOKEN_NULL, true };
+
+                        //GGML_ASSERT(dist_tgt.size <= dist_dft.size);
+
                         // acquire the token probabilities assigned by the draft and target models
                         for (size_t i = 0; i < dist_tgt.size; i++) {
                             if (dist_tgt.data[i].id == drafts[s].tokens[i_dft]) {
@@ -278,7 +277,7 @@ int main(int argc, char ** argv) {
                             accept = true;
                             token_id = drafts[s].tokens[i_dft];
                             token_str = llama_token_to_piece(ctx_tgt, token_id);
-                            llama_sampling_accept(ctx_sampling, ctx_tgt, token_id, true);
+                            gpt_sampler_accept(smpl, token_id, true);
 
                             LOG("draft token %d of sequence %d (%d, '%s') accepted\n", i_dft, s, token_id, token_str.c_str());
                             break;
@@ -289,7 +288,6 @@ int main(int argc, char ** argv) {
                             // calculate residual probability
                             GGML_ASSERT(dist_tgt.sorted);
                             GGML_ASSERT(dist_dft.sorted);
-                            float sum_probs = 0.0f;
 
                             // sort dist by id
                             std::sort(dist_tgt.data, dist_tgt.data + dist_tgt.size, [](const llama_token_data &a, const llama_token_data &b) {
@@ -299,10 +297,18 @@ int main(int argc, char ** argv) {
                                 return a.id < b.id;
                             });
 
+                            float sum_probs = 0.0f;
+
                             for (size_t i = 0; i < dist_tgt.size; i++) {
-                                dist_tgt.data[i].p = std::max(0.0f, dist_tgt.data[i].p - dist_dft.data[i].p);
+                                if (i < dist_dft.size) {
+                                    dist_tgt.data[i].p = std::max(0.0f, dist_tgt.data[i].p - dist_dft.data[i].p);
+                                } else {
+                                    dist_tgt.data[i].p = std::max(0.0f, dist_tgt.data[i].p);
+                                }
+
                                 sum_probs += dist_tgt.data[i].p;
                             }
+
                             for (size_t i = 0; i < dist_tgt.size; i++) {
                                 dist_tgt.data[i].p /= sum_probs;
                             }
@@ -332,21 +338,29 @@ int main(int argc, char ** argv) {
                         // all drafted tokens were rejected
                         // sample from the target model
                         LOG("all drafted tokens were rejected, sampling from residual distribution\n");
-                        token_id = llama_sample_token(ctx_tgt, &dist_tgt);
-                        llama_sampling_accept(ctx_sampling, ctx_tgt, token_id, true);
+                        std::vector<float> probs(dist_tgt.size);
+                        for (size_t i = 0; i < dist_tgt.size; ++i) {
+                            probs[i] = dist_tgt.data[i].p;
+                        }
+
+                        std::discrete_distribution<> dist(probs.begin(), probs.end());
+
+                        const int idx = dist(rng);
+
+                        token_id = dist_tgt.data[idx].id;
+                        gpt_sampler_accept(smpl, token_id, true);
                         token_str = llama_token_to_piece(ctx_tgt, token_id);
                     }
-
                 } else {
                     // greedy verification
 
                     // sample from the target model
                     LOG("sampling target: s_keep = %3d, i_dft = %3d, i_batch_tgt = %3d\n", s_keep, i_dft, drafts[s_keep].i_batch_tgt[i_dft]);
-                    token_id = llama_sampling_sample(ctx_sampling, ctx_tgt, NULL, drafts[s_keep].i_batch_tgt[i_dft]);
+                    token_id = gpt_sampler_sample(smpl, ctx_tgt, drafts[s_keep].i_batch_tgt[i_dft]);
 
-                    llama_sampling_accept(ctx_sampling, ctx_tgt, token_id, true);
+                    gpt_sampler_accept(smpl, token_id, true);
 
-                    //LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_tgt, ctx_sampling->prev).c_str());
+                    //LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_tgt, smpl->prev).c_str());
 
                     token_str = llama_token_to_piece(ctx_tgt, token_id);
 
@@ -434,7 +448,10 @@ int main(int argc, char ** argv) {
             break;
         }
 
-        llama_sampling_cp(ctx_sampling, drafts[0].ctx_sampling);
+        if (drafts[0].smpl) {
+            gpt_sampler_free(drafts[0].smpl);
+        }
+        drafts[0].smpl = gpt_sampler_clone(smpl);
 
         int n_seq_cur  = 1;
         int n_past_cur = n_past_dft;
@@ -463,20 +480,20 @@ int main(int argc, char ** argv) {
                     continue;
                 }
 
-                llama_sampling_sample(drafts[s].ctx_sampling, ctx_dft, NULL, drafts[s].i_batch_dft);
+                gpt_sampler_sample(drafts[s].smpl, ctx_dft, drafts[s].i_batch_dft, true);
 
-                const auto & cur_p = drafts[s].ctx_sampling->cur;
+                const auto * cur_p = gpt_sampler_get_candidates(drafts[s].smpl);
 
-                for (int k = 0; k < std::min(n_seq_dft + 3, (int) cur_p.size()); ++k) {
+                for (int k = 0; k < std::min(n_seq_dft + 3, (int) cur_p->size); ++k) {
                     LOG(" - draft candidate %3d for seq %3d, pos %3d: %6d (%8.3f) '%s'\n",
-                            k, s, i, cur_p[k].id, cur_p[k].p, llama_token_to_piece(ctx_dft, cur_p[k].id).c_str());
+                            k, s, i, cur_p->data[k].id, cur_p->data[k].p, llama_token_to_piece(ctx_dft, cur_p->data[k].id).c_str());
                 }
 
                 std::vector<int> sa(1, s);
 
                 // attempt to split the branch if the probability is high enough
                 for (int f = 1; f < 8; ++f) {
-                    if (n_seq_cur < n_seq_dft && cur_p[f].p > p_split) {
+                    if (n_seq_cur < n_seq_dft && cur_p->data[f].p > p_split) {
                         LOG("splitting seq %3d into %3d\n", s, n_seq_cur);
 
                         llama_kv_cache_seq_rm(ctx_dft,    n_seq_cur, -1, -1);
@@ -503,7 +520,10 @@ int main(int argc, char ** argv) {
                         drafts[n_seq_cur].i_batch_dft = drafts[s].i_batch_dft;
                         drafts[n_seq_cur].i_batch_tgt = drafts[s].i_batch_tgt;
 
-                        llama_sampling_cp(drafts[s].ctx_sampling, drafts[n_seq_cur].ctx_sampling);
+                        if (drafts[n_seq_cur].smpl) {
+                            gpt_sampler_free(drafts[n_seq_cur].smpl);
+                        }
+                        drafts[n_seq_cur].smpl = gpt_sampler_clone(drafts[s].smpl);
 
                         sa.push_back(n_seq_cur);
 
@@ -515,15 +535,15 @@ int main(int argc, char ** argv) {
 
                 // add drafted token for each sequence
                 for (int is = 0; is < (int) sa.size(); ++is) {
-                    const llama_token id = cur_p[is].id;
+                    const llama_token id = cur_p->data[is].id;
 
                     const int s = sa[is];
 
-                    llama_sampling_accept(drafts[s].ctx_sampling, ctx_dft, id, true);
+                    gpt_sampler_accept(drafts[s].smpl, id, true);
 
                     drafts[s].tokens.push_back(id);
                     // save cur_p.data into drafts[s].dists
-                    drafts[s].dists.push_back(cur_p);
+                    drafts[s].dists.push_back({cur_p->data, cur_p->data + cur_p->size});
 
                     // add unique drafted tokens to the target batch
                     drafts[s].i_batch_tgt.push_back(batch_tgt.n_tokens);
@@ -593,17 +613,19 @@ int main(int argc, char ** argv) {
     LOG_TEE("n_accept  = %d\n", n_accept);
     LOG_TEE("accept    = %.3f%%\n", 100.0f * n_accept / n_drafted);
 
-    LOG_TEE("\ndraft:\n");
-    llama_print_timings(ctx_dft);
+    LOG_TEE("\ndraft:\n\n");
+    // TODO: print sampling/grammar timings for all drafts
+    llama_perf_print(ctx_dft, LLAMA_PERF_TYPE_CONTEXT);
 
-    LOG_TEE("\ntarget:\n");
-    llama_print_timings(ctx_tgt);
+    LOG_TEE("\ntarget:\n\n");
+    gpt_perf_print(ctx_tgt, smpl);
 
-    llama_sampling_free(ctx_sampling);
+    gpt_sampler_free(smpl);
     for (int s = 0; s < n_seq_dft; ++s) {
-        llama_sampling_free(drafts[s].ctx_sampling);
+        gpt_sampler_free(drafts[s].smpl);
     }
 
+    llama_sampler_free(softmax);
     llama_batch_free(batch_dft);
 
     llama_free(ctx_tgt);
index a495e866d5a1a735bf828e3c466b29a2f9106cf6..6334fc30d413c424838f6f4e80a545745684282e 100644 (file)
 
 #define LLAMA_DEFAULT_SEED 0xFFFFFFFF
 
+// TODO: use everywhere in the implementation
+#define LLAMA_TOKEN_NULL -1
+
 #define LLAMA_FILE_MAGIC_GGLA 0x67676c61u // 'ggla'
 #define LLAMA_FILE_MAGIC_GGSN 0x6767736eu // 'ggsn'
 #define LLAMA_FILE_MAGIC_GGSQ 0x67677371u // 'ggsq'
 
 #define LLAMA_SESSION_MAGIC   LLAMA_FILE_MAGIC_GGSN
-#define LLAMA_SESSION_VERSION 8
+#define LLAMA_SESSION_VERSION 9
 
 #define LLAMA_STATE_SEQ_MAGIC   LLAMA_FILE_MAGIC_GGSQ
 #define LLAMA_STATE_SEQ_VERSION 2
@@ -53,8 +56,10 @@ extern "C" {
     // TODO: show sample usage
     //
 
+    // struct llama_vocab; // TODO: add in the future
     struct llama_model;
     struct llama_context;
+    struct llama_sampler;
 
     typedef int32_t llama_pos;
     typedef int32_t llama_token;
@@ -201,6 +206,7 @@ extern "C" {
         LLAMA_SPLIT_MODE_ROW     = 2, // split rows across GPUs
     };
 
+    // TODO: simplify (https://github.com/ggerganov/llama.cpp/pull/9294#pullrequestreview-2286561979)
     typedef struct llama_token_data {
         llama_token id; // token id
         float logit;    // log-odds of the token
@@ -208,8 +214,10 @@ extern "C" {
     } llama_token_data;
 
     typedef struct llama_token_data_array {
+        // TODO: consider SoA
         llama_token_data * data;
         size_t size;
+        int64_t selected; // this is the index in the data array (i.e. not the token id)
         bool sorted;
     } llama_token_data_array;
 
@@ -302,7 +310,6 @@ extern "C" {
     // NOTE: changing the default values of parameters marked as [EXPERIMENTAL] may cause crashes or incorrect results in certain configurations
     //       https://github.com/ggerganov/llama.cpp/pull/7544
     struct llama_context_params {
-        uint32_t seed;              // RNG seed, -1 for random
         uint32_t n_ctx;             // text context, 0 = from model
         uint32_t n_batch;           // logical maximum batch size that can be submitted to llama_decode
         uint32_t n_ubatch;          // physical maximum batch size
@@ -330,11 +337,13 @@ extern "C" {
         enum ggml_type type_k; // data type for K cache [EXPERIMENTAL]
         enum ggml_type type_v; // data type for V cache [EXPERIMENTAL]
 
-        // Keep the booleans together to avoid misalignment during copy-by-value.
+        // Keep the booleans together and at the end of the struct to avoid misalignment during copy-by-value.
+        // TODO: move at the end of the struct
         bool logits_all;  // the llama_decode() call computes all logits, not just the last one (DEPRECATED - set llama_batch.logits instead)
         bool embeddings;  // if true, extract embeddings (together with logits)
         bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU
         bool flash_attn;  // whether to use flash attention [EXPERIMENTAL]
+      //bool no_perf;     // whether to measure performance timings, TODO: implement
 
         // Abort callback
         // if it returns true, execution of llama_decode() will be aborted
@@ -358,56 +367,14 @@ extern "C" {
         void * kv_overrides;                 // pointer to vector containing overrides
     } llama_model_quantize_params;
 
-    // grammar types
-    struct llama_grammar;
-
-    // grammar element type
-    enum llama_gretype {
-        // end of rule definition
-        LLAMA_GRETYPE_END            = 0,
-
-        // start of alternate definition for rule
-        LLAMA_GRETYPE_ALT            = 1,
-
-        // non-terminal element: reference to rule
-        LLAMA_GRETYPE_RULE_REF       = 2,
-
-        // terminal element: character (code point)
-        LLAMA_GRETYPE_CHAR           = 3,
-
-        // inverse char(s) ([^a], [^a-b] [^abc])
-        LLAMA_GRETYPE_CHAR_NOT       = 4,
-
-        // modifies a preceding LLAMA_GRETYPE_CHAR or LLAMA_GRETYPE_CHAR_ALT to
-        // be an inclusive range ([a-z])
-        LLAMA_GRETYPE_CHAR_RNG_UPPER = 5,
-
-        // modifies a preceding LLAMA_GRETYPE_CHAR or
-        // LLAMA_GRETYPE_CHAR_RNG_UPPER to add an alternate char to match ([ab], [a-zA])
-        LLAMA_GRETYPE_CHAR_ALT       = 6,
-
-        // any character (.)
-        LLAMA_GRETYPE_CHAR_ANY       = 7,
-    };
+    typedef struct llama_logit_bias {
+        llama_token token;
+        float bias;
+    } llama_logit_bias;
 
-    typedef struct llama_grammar_element {
-        enum llama_gretype type;
-        uint32_t           value; // Unicode code point or rule ID
-    } llama_grammar_element;
-
-    // performance timing information
-    struct llama_timings {
-        double t_start_ms;
-        double t_end_ms;
-        double t_load_ms;
-        double t_sample_ms;
-        double t_p_eval_ms;
-        double t_eval_ms;
-
-        int32_t n_sample;
-        int32_t n_p_eval;
-        int32_t n_eval;
-    };
+    typedef struct llama_sampler_chain_params {
+        bool no_perf; // whether to measure performance timings
+    } llama_sampler_chain_params;
 
     // used in chat template
     typedef struct llama_chat_message {
@@ -419,8 +386,10 @@ extern "C" {
     struct llama_lora_adapter;
 
     // Helpers for getting default parameters
-    LLAMA_API struct llama_model_params llama_model_default_params(void);
-    LLAMA_API struct llama_context_params llama_context_default_params(void);
+    // TODO: update API to start accepting pointers to params structs (https://github.com/ggerganov/llama.cpp/discussions/9172)
+    LLAMA_API struct llama_model_params          llama_model_default_params(void);
+    LLAMA_API struct llama_context_params        llama_context_default_params(void);
+    LLAMA_API struct llama_sampler_chain_params  llama_sampler_chain_default_params(void);
     LLAMA_API struct llama_model_quantize_params llama_model_quantize_default_params(void);
 
     // Initialize the llama + ggml backend
@@ -443,10 +412,11 @@ extern "C" {
 
     LLAMA_API struct llama_model * llama_load_model_from_file(
                              const char * path_model,
-            struct llama_model_params     params);
+              struct llama_model_params   params);
 
     LLAMA_API void llama_free_model(struct llama_model * model);
 
+    // TODO: rename to llama_init_from_model
     LLAMA_API struct llama_context * llama_new_context_with_model(
                      struct llama_model * model,
             struct llama_context_params   params);
@@ -462,23 +432,22 @@ extern "C" {
     LLAMA_API bool llama_supports_mlock      (void);
     LLAMA_API bool llama_supports_gpu_offload(void);
 
-    LLAMA_API const struct llama_model * llama_get_model(const struct llama_context * ctx);
-
     LLAMA_API uint32_t llama_n_ctx      (const struct llama_context * ctx);
     LLAMA_API uint32_t llama_n_batch    (const struct llama_context * ctx);
     LLAMA_API uint32_t llama_n_ubatch   (const struct llama_context * ctx);
     LLAMA_API uint32_t llama_n_seq_max  (const struct llama_context * ctx);
 
-    LLAMA_API enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx);
-
-    LLAMA_API enum llama_vocab_type   llama_vocab_type  (const struct llama_model * model);
-    LLAMA_API enum llama_rope_type    llama_rope_type   (const struct llama_model * model);
-
     LLAMA_API int32_t llama_n_vocab    (const struct llama_model * model);
     LLAMA_API int32_t llama_n_ctx_train(const struct llama_model * model);
     LLAMA_API int32_t llama_n_embd     (const struct llama_model * model);
     LLAMA_API int32_t llama_n_layer    (const struct llama_model * model);
 
+    LLAMA_API const struct llama_model * llama_get_model(const struct llama_context * ctx);
+
+    LLAMA_API enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx);
+    LLAMA_API enum llama_vocab_type   llama_vocab_type  (const struct llama_model * model);
+    LLAMA_API enum llama_rope_type    llama_rope_type   (const struct llama_model * model);
+
     // Get the model's RoPE frequency scaling factor
     LLAMA_API float llama_rope_freq_scale_train(const struct llama_model * model);
 
@@ -706,7 +675,7 @@ extern "C" {
     //
 
     // Returns the *actual* size in bytes of the state
-    // (rng, logits, embedding and kv_cache)
+    // (logits, embedding and kv_cache)
     // Only use when saving the state, not when restoring it, otherwise the size may be too small.
     LLAMA_API size_t llama_state_get_size(struct llama_context * ctx);
     LLAMA_API DEPRECATED(size_t llama_get_state_size(struct llama_context * ctx),
@@ -1009,121 +978,110 @@ extern "C" {
                                int32_t   length);
 
     //
-    // Grammar
+    // Sampling API
+    //
+    // Sample usage:
+    //
+    //    // prepare the sampling chain at the start
+    //    auto sparams = llama_sampler_chain_default_params();
+    //
+    //    llama_sampler * smpl = llama_sampler_chain_init(sparams);
+    //
+    //    llama_sampler_chain_add(smpl, llama_sampler_init_top_k(50));
+    //    llama_sampler_chain_add(smpl, llama_sampler_init_top_p(0.9, 1));
+    //    llama_sampler_chain_add(smpl, llama_sampler_init_temp (0.8));
+    //
+    //    // typically, the chain should end with a sampler such as "greedy", "dist" or "mirostat"
+    //    // this sampler will be responsible to select the actual token
+    //    llama_sampler_chain_add(smpl, llama_sampler_init_dist(seed));
+    //
+    //    ...
+    //
+    //    // decoding loop:
+    //    while (...) {
+    //        ...
+    //
+    //        llama_decode(ctx, batch);
+    //
+    //        // sample from the logits of the last token in the batch
+    //        const llama_token id = llama_sampler_sample(smpl, ctx, -1);
+    //
+    //        // accepting the token updates the internal state of certain samplers (e.g. grammar, repetition, etc.)
+    //        llama_sampler_accept(smpl, id);
+    //        ...
+    //    }
+    //
+    //    llama_sampler_free(smpl);
+    //
+    // TODO: In the future, llama_sampler will be utilized to offload the sampling to the backends (e.g. GPU).
+    // TODO: in the future, the entire sampling API that uses llama_model should start using llama_vocab
     //
 
-    /// Initialize a llama_grammar.
-    ///
-    /// @param rules The rule elements of the grammar to initialize.
-    /// @param n_rules The number of rules.
-    /// @param start_rule_index The index of the root rule (the starting point of the grammar).
-    /// @return The initialized llama_grammar or nullptr if initialization failed.
-    LLAMA_API struct llama_grammar * llama_grammar_init(
-            const llama_grammar_element ** rules,
-                                 size_t    n_rules,
-                                 size_t    start_rule_index);
-
-    LLAMA_API void llama_grammar_free(struct llama_grammar * grammar);
-
-    LLAMA_API struct llama_grammar * llama_grammar_copy(const struct llama_grammar * grammar);
-
-    /// @details Apply constraints from grammar
-    LLAMA_API void llama_grammar_sample(
-            const struct llama_grammar * grammar,
-            const struct llama_context * ctx,
-                llama_token_data_array * candidates);
-    LLAMA_API DEPRECATED(void llama_sample_grammar(
-            struct llama_context * ctx,
-          llama_token_data_array * candidates,
-      const struct llama_grammar * grammar),
-        "use llama_grammar_sample instead");
+    typedef void * llama_sampler_context_t;
 
-    /// @details Accepts the sampled token into the grammar
-    LLAMA_API void llama_grammar_accept_token(
-            struct llama_grammar * grammar,
-            struct llama_context * ctx,
-                     llama_token   token);
+    // user code can implement the interface below in order to create custom llama_sampler
+    struct llama_sampler_i {
+        const char *           (*name)  (const struct llama_sampler * smpl);                                 // can be NULL
+        void                   (*accept)(      struct llama_sampler * smpl, llama_token token);              // can be NULL
+        void                   (*apply) (      struct llama_sampler * smpl, llama_token_data_array * cur_p); // required
+        void                   (*reset) (      struct llama_sampler * smpl);                                 // can be NULL
+        struct llama_sampler * (*clone) (const struct llama_sampler * smpl);                                 // can be NULL if ctx is NULL
+        void                   (*free)  (      struct llama_sampler * smpl);                                 // can be NULL if ctx is NULL
 
-    //
-    // Sampling functions
-    //
+        // TODO: API for internal libllama usage for appending the sampling to an existing ggml_cgraph
+        //void (*apply_ggml) (struct llama_sampler * smpl, ...);
+    };
 
-    // Sets the current rng seed.
-    LLAMA_API void llama_set_rng_seed(struct llama_context * ctx, uint32_t seed);
+    struct llama_sampler {
+        struct llama_sampler_i  * iface;
+        llama_sampler_context_t   ctx;
+    };
 
-    /// @details Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix.
-    /// @details Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details.
-    LLAMA_API void llama_sample_repetition_penalties(
-            struct llama_context * ctx,
-          llama_token_data_array * candidates,
-               const llama_token * last_tokens,
-                          size_t   penalty_last_n,
-                           float   penalty_repeat,
-                           float   penalty_freq,
-                           float   penalty_present);
-
-    /// @details Apply classifier-free guidance to the logits as described in academic paper "Stay on topic with Classifier-Free Guidance" https://arxiv.org/abs/2306.17806
-    /// @param logits Logits extracted from the original generation context.
-    /// @param logits_guidance Logits extracted from a separate context from the same model. Other than a negative prompt at the beginning, it should have all generated and user input tokens copied from the main context.
-    /// @param scale Guidance strength. 1.0f means no guidance. Higher values mean stronger guidance.
-    LLAMA_API void llama_sample_apply_guidance(
-              struct llama_context * ctx,
-                             float * logits,
-                             float * logits_guidance,
-                             float   scale);
+    // mirror of llama_sampler_i:
+    LLAMA_API const char *           llama_sampler_name  (const struct llama_sampler * smpl);
+    LLAMA_API void                   llama_sampler_accept(      struct llama_sampler * smpl, llama_token token);
+    LLAMA_API void                   llama_sampler_apply (      struct llama_sampler * smpl, llama_token_data_array * cur_p);
+    LLAMA_API void                   llama_sampler_reset (      struct llama_sampler * smpl);
+    LLAMA_API struct llama_sampler * llama_sampler_clone (const struct llama_sampler * smpl);
+    // important: do not free if the sampler has been added to a llama_sampler_chain (via llama_sampler_chain_add)
+    LLAMA_API void                   llama_sampler_free  (      struct llama_sampler * smpl);
+
+    // llama_sampler_chain
+    // a type of llama_sampler that can chain multiple samplers one after another
+
+    LLAMA_API struct llama_sampler * llama_sampler_chain_init(struct llama_sampler_chain_params params);
+
+    // important: takes ownership of the sampler object and will free it when llama_sampler_free is called
+    LLAMA_API void                   llama_sampler_chain_add(      struct llama_sampler * chain, struct llama_sampler * smpl);
+    LLAMA_API struct llama_sampler * llama_sampler_chain_get(const struct llama_sampler * chain, int32_t i);
+    LLAMA_API int                    llama_sampler_chain_n  (const struct llama_sampler * chain);
+
+    // available samplers:
+
+    LLAMA_API struct llama_sampler * llama_sampler_init_greedy     (void);
+    LLAMA_API struct llama_sampler * llama_sampler_init_dist       (uint32_t seed);
 
     /// @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits.
-    LLAMA_API void llama_sample_softmax(
-            struct llama_context * ctx,
-          llama_token_data_array * candidates);
+    LLAMA_API struct llama_sampler * llama_sampler_init_softmax    (void);
 
     /// @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
-    LLAMA_API void llama_sample_top_k(
-            struct llama_context * ctx,
-          llama_token_data_array * candidates,
-                         int32_t   k,
-                          size_t   min_keep);
+    LLAMA_API struct llama_sampler * llama_sampler_init_top_k      (int32_t k);
 
     /// @details Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
-    LLAMA_API void llama_sample_top_p(
-            struct llama_context * ctx,
-          llama_token_data_array * candidates,
-                           float   p,
-                          size_t   min_keep);
+    LLAMA_API struct llama_sampler * llama_sampler_init_top_p      (float   p, size_t min_keep);
 
     /// @details Minimum P sampling as described in https://github.com/ggerganov/llama.cpp/pull/3841
-    LLAMA_API void llama_sample_min_p(
-            struct llama_context * ctx,
-          llama_token_data_array * candidates,
-                           float   p,
-                          size_t   min_keep);
+    LLAMA_API struct llama_sampler * llama_sampler_init_min_p      (float   p, size_t min_keep);
 
     /// @details Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/.
-    LLAMA_API void llama_sample_tail_free(
-            struct llama_context * ctx,
-          llama_token_data_array * candidates,
-                           float   z,
-                          size_t   min_keep);
+    LLAMA_API struct llama_sampler * llama_sampler_init_tail_free  (float   z, size_t min_keep);
 
     /// @details Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666.
-    LLAMA_API void llama_sample_typical(
-            struct llama_context * ctx,
-          llama_token_data_array * candidates,
-                           float   p,
-                          size_t   min_keep);
+    LLAMA_API struct llama_sampler * llama_sampler_init_typical    (float   p, size_t min_keep);
+    LLAMA_API struct llama_sampler * llama_sampler_init_temp       (float   t);
 
-    /// @details Dynamic temperature implementation described in the paper https://arxiv.org/abs/2309.02772.
-    LLAMA_API void llama_sample_entropy(
-            struct llama_context * ctx,
-          llama_token_data_array * candidates_p,
-                           float   min_temp,
-                           float   max_temp,
-                           float   exponent_val);
-
-    LLAMA_API void llama_sample_temp(
-            struct llama_context * ctx,
-          llama_token_data_array * candidates,
-                           float   temp);
+    /// @details Dynamic temperature implementation (a.k.a. entropy) described in the paper https://arxiv.org/abs/2309.02772.
+    LLAMA_API struct llama_sampler * llama_sampler_init_temp_ext   (float   t, float   delta, float exponent);
 
     /// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
     /// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
@@ -1131,36 +1089,57 @@ extern "C" {
     /// @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.
     /// @param m The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`. In the paper, they use `m = 100`, but you can experiment with different values to see how it affects the performance of the algorithm.
     /// @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.
-    LLAMA_API llama_token llama_sample_token_mirostat(
-            struct llama_context * ctx,
-          llama_token_data_array * candidates,
-                           float   tau,
-                           float   eta,
-                         int32_t   m,
-                           float * mu);
+    LLAMA_API struct llama_sampler * llama_sampler_init_mirostat(
+                             int32_t   n_vocab,
+                            uint32_t   seed,
+                               float   tau,
+                               float   eta,
+                             int32_t   m);
 
     /// @details Mirostat 2.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
     /// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
     /// @param tau  The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
     /// @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.
     /// @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.
-    LLAMA_API llama_token llama_sample_token_mirostat_v2(
-            struct llama_context * ctx,
-          llama_token_data_array * candidates,
-                           float   tau,
-                           float   eta,
-                           float * mu);
-
-    /// @details Selects the token with the highest probability.
-    ///          Does not compute the token probabilities. Use llama_sample_softmax() instead.
-    LLAMA_API llama_token llama_sample_token_greedy(
-            struct llama_context * ctx,
-          llama_token_data_array * candidates);
+    LLAMA_API struct llama_sampler * llama_sampler_init_mirostat_v2(
+                            uint32_t   seed,
+                               float   tau,
+                               float   eta);
+
+    LLAMA_API struct llama_sampler * llama_sampler_init_grammar(
+            const struct llama_model * model,
+                          const char * grammar_str,
+                          const char * grammar_root);
+
+    LLAMA_API struct llama_sampler * llama_sampler_init_penalties(
+                             int32_t   n_vocab,         // llama_n_vocab()
+                         llama_token   special_eos_id,  // llama_token_eos()
+                         llama_token   linefeed_id,     // llama_token_nl()
+                             int32_t   penalty_last_n,  // last n tokens to penalize (0 = disable penalty, -1 = context size)
+                               float   penalty_repeat,  // 1.0 = disabled
+                               float   penalty_freq,    // 0.0 = disabled
+                               float   penalty_present, // 0.0 = disabled
+                                bool   penalize_nl,     // consider newlines as a repeatable token
+                                bool   ignore_eos);     // ignore the end-of-sequence token
+
+    LLAMA_API struct llama_sampler * llama_sampler_init_logit_bias(
+                             int32_t   n_vocab,
+                             int32_t   n_logit_bias,
+              const llama_logit_bias * logit_bias);
+
+    // Shorthand for:
+    //
+    //    const auto * logits = llama_get_logits_ith(ctx, idx);
+    //    llama_token_data_array cur_p = { ... init from logits ... };
+    //    llama_sampler_apply(smpl, &cur_p);
+    //    return cur_p.data[cur_p.selected].id;
+    //
+    // At this point, this is mostly a convenience function.
+    //
+    LLAMA_API llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_context * ctx, int32_t idx);
 
-    /// @details Randomly selects a token from the candidates based on their probabilities using the RNG of ctx.
-    LLAMA_API llama_token llama_sample_token(
-            struct llama_context * ctx,
-          llama_token_data_array * candidates);
+    // TODO: extend in the future
+    //LLAMA_API void llama_decode_with_sampler(struct llama_context * ctx, struct llama_sampler * smpl, struct llama_batch batch, ...);
 
     //
     // Model split
@@ -1176,12 +1155,6 @@ extern "C" {
     //  Returns the split_prefix length.
     LLAMA_API int llama_split_prefix(char * split_prefix, size_t maxlen, const char * split_path, int split_no, int split_count);
 
-    // Performance information
-    LLAMA_API struct llama_timings llama_get_timings(struct llama_context * ctx);
-
-    LLAMA_API void llama_print_timings(struct llama_context * ctx);
-    LLAMA_API void llama_reset_timings(struct llama_context * ctx);
-
     // Print system information
     LLAMA_API const char * llama_print_system_info(void);
 
@@ -1189,65 +1162,24 @@ extern "C" {
     // If this is not called, or NULL is supplied, everything is output on stderr.
     LLAMA_API void llama_log_set(ggml_log_callback log_callback, void * user_data);
 
-    LLAMA_API void llama_dump_timing_info_yaml(FILE * stream, const struct llama_context * ctx);
-
-#ifdef __cplusplus
-}
-#endif
-
-// Internal API to be implemented by llama.cpp and used by tests/benchmarks only
-#ifdef LLAMA_API_INTERNAL
-
-#include <random>
-#include <string>
-#include <vector>
-
-struct ggml_tensor;
-
-const std::vector<std::pair<std::string, struct ggml_tensor *>> & llama_internal_get_tensor_map(
-    struct llama_context * ctx
-);
-
-struct llama_partial_utf8 {
-    uint32_t value;    // bit value so far (unshifted)
-    int      n_remain; // num bytes remaining; -1 indicates invalid sequence
-};
-
-struct llama_grammar_candidate {
-    size_t               index;
-    const uint32_t     * code_points;
-    llama_partial_utf8   partial_utf8;
-};
-
-using llama_grammar_rule  = std::vector<      llama_grammar_element>;
-using llama_grammar_stack = std::vector<const llama_grammar_element *>;
-
-using llama_grammar_rules      = std::vector<llama_grammar_rule>;
-using llama_grammar_stacks     = std::vector<llama_grammar_stack>;
-using llama_grammar_candidates = std::vector<llama_grammar_candidate>;
-
-const llama_grammar_rules  & llama_grammar_get_rules (const struct llama_grammar * grammar);
-      llama_grammar_stacks & llama_grammar_get_stacks(      struct llama_grammar * grammar);
-
-void llama_grammar_accept(
-        const llama_grammar_rules  & rules,
-        const llama_grammar_stacks & stacks,
-        const uint32_t chr,
-              llama_grammar_stacks & new_stacks);
+    //
+    // Performance utils
+    //
+    // NOTE: Used by llama.cpp examples, avoid using in third-party apps. Instead, do your own performance measurements.
+    //
 
-std::vector<llama_grammar_candidate> llama_grammar_reject_candidates_for_stack(
-        const llama_grammar_rules & rules,
-        const llama_grammar_stack & stack,
-        const llama_grammar_candidates & candidates);
+    enum llama_perf_type {
+        LLAMA_PERF_TYPE_CONTEXT       = 0,
+        LLAMA_PERF_TYPE_SAMPLER_CHAIN = 1,
+    };
 
-std::pair<std::vector<uint32_t>, llama_partial_utf8> decode_utf8(
-        const std::string & src,
-        llama_partial_utf8 partial_start);
+    LLAMA_API void llama_perf_print(const void * ctx, enum llama_perf_type type);
+    LLAMA_API void llama_perf_reset(      void * ctx, enum llama_perf_type type);
 
-// Randomly selects a token from the candidates based on their probabilities using given std::mt19937.
-// This is a temporary workaround in order to fix race conditions when sampling with multiple sequences.
-llama_token llama_sample_token_with_rng(struct llama_context * ctx, llama_token_data_array * candidates, std::mt19937 & rng);
+    LLAMA_API void llama_perf_dump_yaml(FILE * stream, const struct llama_context * ctx);
 
-#endif // LLAMA_API_INTERNAL
+#ifdef __cplusplus
+}
+#endif
 
 #endif // LLAMA_H
index b123d733100ce836878170b7b2048c376bd91655..74e9f64b393b2f2e144f78b1e30830771e91099b 100644 (file)
@@ -3,11 +3,31 @@
 #include "llama-vocab.h"
 #include "llama-sampling.h"
 
+#include <cmath>
 #include <algorithm>
+#include <stdexcept>
 
-// Decodes a UTF-8 string which may end in an incomplete sequence. Adds a terminating 0 for use as
-// pointer. If an invalid sequence is encountered, returns `llama_partial_utf8.n_remain == -1`.
-std::pair<std::vector<uint32_t>, llama_partial_utf8> decode_utf8(
+//
+// helpers
+//
+
+// NOTE: assumes valid utf8 (but checks for overrun)
+static std::pair<uint32_t, const char *> decode_utf8(const char * src) {
+    static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 };
+    uint8_t  first_byte = static_cast<uint8_t>(*src);
+    uint8_t  highbits   = first_byte >> 4;
+    int      len        = lookup[highbits];
+    uint8_t  mask       = (1 << (8 - len)) - 1;
+    uint32_t value      = first_byte & mask;
+    const char * end    = src + len; // may overrun!
+    const char * pos    = src + 1;
+    for ( ; pos < end && *pos; pos++) {
+        value = (value << 6) + (static_cast<uint8_t>(*pos) & 0x3F);
+    }
+    return std::make_pair(value, pos);
+}
+
+static std::pair<std::vector<uint32_t>, llama_partial_utf8> decode_utf8(
         const std::string & src,
         llama_partial_utf8 partial_start) {
     static const int      lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 2, 2, 3, 4 };
@@ -40,7 +60,7 @@ std::pair<std::vector<uint32_t>, llama_partial_utf8> decode_utf8(
     while (*pos != 0) {
         uint8_t first_byte = static_cast<uint8_t>(*pos);
         uint8_t highbits   = first_byte >> 4;
-                n_remain   = lookup[highbits] - 1;
+        n_remain   = lookup[highbits] - 1;
 
         if (n_remain < 0) {
             // invalid sequence, abort
@@ -50,7 +70,7 @@ std::pair<std::vector<uint32_t>, llama_partial_utf8> decode_utf8(
         }
 
         uint8_t mask  = (1 << (7 - n_remain)) - 1;
-                value = first_byte & mask;
+        value = first_byte & mask;
 
         ++pos;
         while (*pos != 0 && n_remain > 0) {
@@ -67,12 +87,510 @@ std::pair<std::vector<uint32_t>, llama_partial_utf8> decode_utf8(
     return std::make_pair(std::move(code_points), llama_partial_utf8{ value, n_remain });
 }
 
-const llama_grammar_rules & llama_grammar_get_rules(const struct llama_grammar * grammar) {
-    return grammar->rules;
+static bool is_digit_char(char c) {
+    return '0' <= c && c <= '9';
 }
 
-llama_grammar_stacks & llama_grammar_get_stacks(struct llama_grammar * grammar) {
-    return grammar->stacks;
+static bool is_word_char(char c) {
+    return ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || c == '-' || is_digit_char(c);
+}
+
+static std::pair<uint32_t, const char *> parse_hex(const char * src, int size) {
+    const char * pos   = src;
+    const char * end   = src + size;
+    uint32_t     value = 0;
+    for ( ; pos < end && *pos; pos++) {
+        value <<= 4;
+        char c = *pos;
+        if ('a' <= c && c <= 'f') {
+            value += c - 'a' + 10;
+        } else if ('A' <= c && c <= 'F') {
+            value += c - 'A' + 10;
+        } else if ('0' <= c && c <= '9') {
+            value += c - '0';
+        } else {
+            break;
+        }
+    }
+    if (pos != end) {
+        throw std::runtime_error("expecting " + std::to_string(size) + " hex chars at " + src);
+    }
+    return std::make_pair(value, pos);
+}
+
+static const char * parse_space(const char * src, bool newline_ok) {
+    const char * pos = src;
+    while (*pos == ' ' || *pos == '\t' || *pos == '#' ||
+            (newline_ok && (*pos == '\r' || *pos == '\n'))) {
+        if (*pos == '#') {
+            while (*pos && *pos != '\r' && *pos != '\n') {
+                pos++;
+            }
+        } else {
+            pos++;
+        }
+    }
+    return pos;
+}
+
+static const char * parse_name(const char * src) {
+    const char * pos = src;
+    while (is_word_char(*pos)) {
+        pos++;
+    }
+    if (pos == src) {
+        throw std::runtime_error(std::string("expecting name at ") + src);
+    }
+    return pos;
+}
+
+static const char * parse_int(const char * src) {
+    const char * pos = src;
+    while (is_digit_char(*pos)) {
+        pos++;
+    }
+    if (pos == src) {
+        throw std::runtime_error(std::string("expecting integer at ") + src);
+    }
+    return pos;
+}
+
+static std::pair<uint32_t, const char *> parse_char(const char * src) {
+    if (*src == '\\') {
+        switch (src[1]) {
+            case 'x': return parse_hex(src + 2, 2);
+            case 'u': return parse_hex(src + 2, 4);
+            case 'U': return parse_hex(src + 2, 8);
+            case 't': return std::make_pair('\t', src + 2);
+            case 'r': return std::make_pair('\r', src + 2);
+            case 'n': return std::make_pair('\n', src + 2);
+            case '\\':
+            case '"':
+            case '[':
+            case ']':
+                      return std::make_pair(src[1], src + 2);
+            default:
+                      throw std::runtime_error(std::string("unknown escape at ") + src);
+        }
+    } else if (*src) {
+        return decode_utf8(src);
+    }
+    throw std::runtime_error("unexpected end of input");
+}
+
+static void print_grammar_char(FILE * file, uint32_t c) {
+    if (0x20 <= c && c <= 0x7f) {
+        fprintf(file, "%c", static_cast<char>(c));
+    } else {
+        // cop out of encoding UTF-8
+        fprintf(file, "<U+%04X>", c);
+    }
+}
+
+static bool is_char_element(llama_grammar_element elem) {
+    switch (elem.type) {
+        case LLAMA_GRETYPE_CHAR:           return true;
+        case LLAMA_GRETYPE_CHAR_NOT:       return true;
+        case LLAMA_GRETYPE_CHAR_ALT:       return true;
+        case LLAMA_GRETYPE_CHAR_RNG_UPPER: return true;
+        case LLAMA_GRETYPE_CHAR_ANY:       return true;
+        default:                           return false;
+    }
+}
+
+static void print_rule_binary(FILE * file, const llama_grammar_rule & rule) {
+    for (auto elem : rule) {
+        switch (elem.type) {
+            case LLAMA_GRETYPE_END:            fprintf(file, "END");            break;
+            case LLAMA_GRETYPE_ALT:            fprintf(file, "ALT");            break;
+            case LLAMA_GRETYPE_RULE_REF:       fprintf(file, "RULE_REF");       break;
+            case LLAMA_GRETYPE_CHAR:           fprintf(file, "CHAR");           break;
+            case LLAMA_GRETYPE_CHAR_NOT:       fprintf(file, "CHAR_NOT");       break;
+            case LLAMA_GRETYPE_CHAR_RNG_UPPER: fprintf(file, "CHAR_RNG_UPPER"); break;
+            case LLAMA_GRETYPE_CHAR_ALT:       fprintf(file, "CHAR_ALT");       break;
+            case LLAMA_GRETYPE_CHAR_ANY:       fprintf(file, "CHAR_ANY");       break;
+        }
+        switch (elem.type) {
+            case LLAMA_GRETYPE_END:
+            case LLAMA_GRETYPE_ALT:
+            case LLAMA_GRETYPE_RULE_REF:
+                fprintf(file, "(%u) ", elem.value);
+                break;
+            case LLAMA_GRETYPE_CHAR:
+            case LLAMA_GRETYPE_CHAR_NOT:
+            case LLAMA_GRETYPE_CHAR_RNG_UPPER:
+            case LLAMA_GRETYPE_CHAR_ALT:
+            case LLAMA_GRETYPE_CHAR_ANY:
+                fprintf(file, "(\"");
+                print_grammar_char(file, elem.value);
+                fprintf(file, "\") ");
+                break;
+        }
+    }
+    fprintf(file, "\n");
+}
+
+static void print_rule(
+        FILE     * file,
+        uint32_t   rule_id,
+        const llama_grammar_rule & rule,
+        const std::map<uint32_t, std::string> & symbol_id_names) {
+    if (rule.empty() || rule.back().type != LLAMA_GRETYPE_END) {
+        throw std::runtime_error(
+            "malformed rule, does not end with LLAMA_GRETYPE_END: " + std::to_string(rule_id));
+    }
+    fprintf(file, "%s ::= ", symbol_id_names.at(rule_id).c_str());
+    for (size_t i = 0, end = rule.size() - 1; i < end; i++) {
+        llama_grammar_element elem = rule[i];
+        switch (elem.type) {
+            case LLAMA_GRETYPE_END:
+                throw std::runtime_error(
+                    "unexpected end of rule: " + std::to_string(rule_id) + "," +
+                    std::to_string(i));
+            case LLAMA_GRETYPE_ALT:
+                fprintf(file, "| ");
+                break;
+            case LLAMA_GRETYPE_RULE_REF:
+                fprintf(file, "%s ", symbol_id_names.at(elem.value).c_str());
+                break;
+            case LLAMA_GRETYPE_CHAR:
+                fprintf(file, "[");
+                print_grammar_char(file, elem.value);
+                break;
+            case LLAMA_GRETYPE_CHAR_NOT:
+                fprintf(file, "[^");
+                print_grammar_char(file, elem.value);
+                break;
+            case LLAMA_GRETYPE_CHAR_RNG_UPPER:
+                if (i == 0 || !is_char_element(rule[i - 1])) {
+                    throw std::runtime_error(
+                        "LLAMA_GRETYPE_CHAR_RNG_UPPER without preceding char: " +
+                        std::to_string(rule_id) + "," + std::to_string(i));
+                }
+                fprintf(file, "-");
+                print_grammar_char(file, elem.value);
+                break;
+            case LLAMA_GRETYPE_CHAR_ALT:
+                if (i == 0 || !is_char_element(rule[i - 1])) {
+                    throw std::runtime_error(
+                        "LLAMA_GRETYPE_CHAR_ALT without preceding char: " +
+                        std::to_string(rule_id) + "," + std::to_string(i));
+                }
+                print_grammar_char(file, elem.value);
+                break;
+            case LLAMA_GRETYPE_CHAR_ANY:
+                fprintf(file, ".");
+                break;
+        }
+        if (is_char_element(elem)) {
+            switch (rule[i + 1].type) {
+                case LLAMA_GRETYPE_CHAR_ALT:
+                case LLAMA_GRETYPE_CHAR_RNG_UPPER:
+                case LLAMA_GRETYPE_CHAR_ANY:
+                    break;
+                default:
+                    fprintf(file, "] ");
+            }
+        }
+    }
+    fprintf(file, "\n");
+}
+
+//
+// implementation
+//
+
+uint32_t llama_grammar_parser::get_symbol_id(const char * src, size_t len) {
+    uint32_t next_id = static_cast<uint32_t>(symbol_ids.size());
+    auto result = symbol_ids.emplace(std::string(src, len), next_id);
+    return result.first->second;
+}
+
+uint32_t llama_grammar_parser::generate_symbol_id(const std::string & base_name) {
+    uint32_t next_id = static_cast<uint32_t>(symbol_ids.size());
+    symbol_ids[base_name + '_' + std::to_string(next_id)] = next_id;
+    return next_id;
+}
+
+void llama_grammar_parser::add_rule(uint32_t rule_id, const llama_grammar_rule & rule) {
+    if (rules.size() <= rule_id) {
+        rules.resize(rule_id + 1);
+    }
+    rules[rule_id] = rule;
+}
+
+const char * llama_grammar_parser::parse_alternates(
+        const char        * src,
+        const std::string & rule_name,
+        uint32_t            rule_id,
+        bool                is_nested) {
+    llama_grammar_rule rule;
+    const char * pos = parse_sequence(src, rule_name, rule, is_nested);
+    while (*pos == '|') {
+        rule.push_back({LLAMA_GRETYPE_ALT, 0});
+        pos = parse_space(pos + 1, true);
+        pos = parse_sequence(pos, rule_name, rule, is_nested);
+    }
+    rule.push_back({LLAMA_GRETYPE_END, 0});
+    add_rule(rule_id, rule);
+    return pos;
+}
+
+const char * llama_grammar_parser::parse_sequence(
+        const char         * src,
+        const std::string  & rule_name,
+        llama_grammar_rule & rule,
+        bool               is_nested) {
+    size_t last_sym_start = rule.size();
+    const char * pos = src;
+
+        auto handle_repetitions = [&](int min_times, int max_times) {
+
+            if (last_sym_start == rule.size()) {
+                throw std::runtime_error(std::string("expecting preceding item to */+/?/{ at ") + pos);
+            }
+
+            // apply transformation to previous symbol (last_sym_start to end) according to
+            // the following rewrite rules:
+            // S{m,n} --> S S S (m times) S'(n-m)
+            //            S'(x)   ::= S S'(x-1) |
+            //            (... n-m definitions of these S' rules ...)
+            //            S'(1)   ::= S |
+            // S{m,} -->  S S S (m times) S'
+            //            S'     ::= S S' |
+            // S*     --> S{0,}
+            //        --> S'     ::= S S' |
+            // S+     --> S{1,}
+            //        --> S S'
+            //            S'     ::= S S' |
+            // S?     --> S{0,1}
+            //        --> S'
+            //            S'     ::= S |
+
+            llama_grammar_rule prev_rule(rule.begin() + last_sym_start, rule.end());
+            if (min_times == 0) {
+                rule.resize(last_sym_start);
+            } else {
+                // Repeat the previous elements (min_times - 1) times
+                for (int i = 1; i < min_times; i++) {
+                    rule.insert(rule.end(), prev_rule.begin(), prev_rule.end());
+                }
+            }
+
+            uint32_t last_rec_rule_id = 0;
+            auto n_opt = max_times < 0 ? 1 : max_times - min_times;
+
+            llama_grammar_rule rec_rule(prev_rule);
+            for (int i = 0; i < n_opt; i++) {
+                rec_rule.resize(prev_rule.size());
+                uint32_t rec_rule_id = generate_symbol_id( rule_name);
+                if (i > 0 || max_times < 0) {
+                    rec_rule.push_back({LLAMA_GRETYPE_RULE_REF, max_times < 0 ? rec_rule_id : last_rec_rule_id});
+                }
+                rec_rule.push_back({LLAMA_GRETYPE_ALT, 0});
+                rec_rule.push_back({LLAMA_GRETYPE_END, 0});
+                add_rule( rec_rule_id, rec_rule);
+                last_rec_rule_id = rec_rule_id;
+            }
+            if (n_opt > 0) {
+                rule.push_back({LLAMA_GRETYPE_RULE_REF, last_rec_rule_id});
+            }
+        };
+
+        while (*pos) {
+            if (*pos == '"') { // literal string
+                pos++;
+                last_sym_start = rule.size();
+                while (*pos != '"') {
+                    if (!*pos) {
+                        throw std::runtime_error("unexpected end of input");
+                    }
+                    auto char_pair = parse_char(pos);
+                         pos       = char_pair.second;
+                    rule.push_back({LLAMA_GRETYPE_CHAR, char_pair.first});
+                }
+                pos = parse_space(pos + 1, is_nested);
+            } else if (*pos == '[') { // char range(s)
+                pos++;
+                enum llama_gretype start_type = LLAMA_GRETYPE_CHAR;
+                if (*pos == '^') {
+                    pos++;
+                    start_type = LLAMA_GRETYPE_CHAR_NOT;
+                }
+                last_sym_start = rule.size();
+                while (*pos != ']') {
+                    if (!*pos) {
+                        throw std::runtime_error("unexpected end of input");
+                    }
+                    auto char_pair = parse_char(pos);
+                         pos       = char_pair.second;
+                    enum llama_gretype type = last_sym_start < rule.size()
+                        ? LLAMA_GRETYPE_CHAR_ALT
+                        : start_type;
+
+                    rule.push_back({type, char_pair.first});
+                    if (pos[0] == '-' && pos[1] != ']') {
+                        if (!pos[1]) {
+                            throw std::runtime_error("unexpected end of input");
+                        }
+                        auto endchar_pair = parse_char(pos + 1);
+                             pos          = endchar_pair.second;
+                        rule.push_back({LLAMA_GRETYPE_CHAR_RNG_UPPER, endchar_pair.first});
+                    }
+                }
+                pos = parse_space(pos + 1, is_nested);
+            } else if (is_word_char(*pos)) { // rule reference
+                const char * name_end    = parse_name(pos);
+                uint32_t ref_rule_id = get_symbol_id(pos, name_end - pos);
+                pos = parse_space(name_end, is_nested);
+                last_sym_start = rule.size();
+                rule.push_back({LLAMA_GRETYPE_RULE_REF, ref_rule_id});
+            } else if (*pos == '(') { // grouping
+                // parse nested alternates into synthesized rule
+                pos = parse_space(pos + 1, true);
+                uint32_t sub_rule_id = generate_symbol_id(rule_name);
+                pos = parse_alternates(pos, rule_name, sub_rule_id, true);
+                last_sym_start = rule.size();
+                // output reference to synthesized rule
+                rule.push_back({LLAMA_GRETYPE_RULE_REF, sub_rule_id});
+                if (*pos != ')') {
+                    throw std::runtime_error(std::string("expecting ')' at ") + pos);
+                }
+                pos = parse_space(pos + 1, is_nested);
+            } else if (*pos == '.') { // any char
+                last_sym_start = rule.size();
+                rule.push_back({LLAMA_GRETYPE_CHAR_ANY, 0});
+                pos = parse_space(pos + 1, is_nested);
+            } else if (*pos == '*') {
+                pos = parse_space(pos + 1, is_nested);
+                handle_repetitions(0, -1);
+            } else if (*pos == '+') {
+                pos = parse_space(pos + 1, is_nested);
+                handle_repetitions(1, -1);
+            } else if (*pos == '?') {
+                pos = parse_space(pos + 1, is_nested);
+                handle_repetitions(0, 1);
+            } else if (*pos == '{') {
+                pos = parse_space(pos + 1, is_nested);
+
+                if (!is_digit_char(*pos)) {
+                    throw std::runtime_error(std::string("expecting an int at ") + pos);
+                }
+                const char * int_end = parse_int(pos);
+                int min_times = std::stoul(std::string(pos, int_end - pos));
+                pos = parse_space(int_end, is_nested);
+
+                int max_times = -1;
+
+                if (*pos == '}') {
+                    max_times = min_times;
+                    pos = parse_space(pos + 1, is_nested);
+                } else if (*pos == ',') {
+                    pos = parse_space(pos + 1, is_nested);
+
+                    if (is_digit_char(*pos)) {
+                        const char * int_end = parse_int(pos);
+                        max_times = std::stoul(std::string(pos, int_end - pos));
+                        pos = parse_space(int_end, is_nested);
+                    }
+
+                    if (*pos != '}') {
+                        throw std::runtime_error(std::string("expecting '}' at ") + pos);
+                    }
+                    pos = parse_space(pos + 1, is_nested);
+                } else {
+                    throw std::runtime_error(std::string("expecting ',' at ") + pos);
+                }
+                handle_repetitions(min_times, max_times);
+            } else {
+                break;
+            }
+        }
+        return pos;
+    }
+
+const char * llama_grammar_parser::parse_rule(const char * src) {
+        const char * name_end = parse_name(src);
+        const char * pos      = parse_space(name_end, false);
+        size_t       name_len = name_end - src;
+        uint32_t     rule_id  = get_symbol_id(src, name_len);
+        const std::string name(src, name_len);
+
+        if (!(pos[0] == ':' && pos[1] == ':' && pos[2] == '=')) {
+            throw std::runtime_error(std::string("expecting ::= at ") + pos);
+        }
+        pos = parse_space(pos + 3, true);
+
+        pos = parse_alternates(pos, name, rule_id, false);
+
+        if (*pos == '\r') {
+            pos += pos[1] == '\n' ? 2 : 1;
+        } else if (*pos == '\n') {
+            pos++;
+        } else if (*pos) {
+            throw std::runtime_error(std::string("expecting newline or end at ") + pos);
+        }
+        return parse_space(pos, true);
+    }
+
+bool llama_grammar_parser::parse(const char * src) {
+    try {
+        const char * pos = parse_space(src, true);
+        while (*pos) {
+            pos = parse_rule(pos);
+        }
+        // Validate the state to ensure that all rules are defined
+        for (const auto & rule : rules) {
+            if (rule.empty()) {
+                throw std::runtime_error("Undefined rule");
+            }
+            for (const auto & elem : rule) {
+                if (elem.type == LLAMA_GRETYPE_RULE_REF) {
+                    // Ensure that the rule at that location exists
+                    if (elem.value >= rules.size() || rules[elem.value].empty()) {
+                        // Get the name of the rule that is missing
+                        for (const auto & kv : symbol_ids) {
+                            if (kv.second == elem.value) {
+                                throw std::runtime_error("Undefined rule identifier '" + kv.first + "'");
+                            }
+                        }
+                    }
+                }
+            }
+        }
+    } catch (const std::exception & err) {
+        fprintf(stderr, "%s: error parsing grammar: %s\n", __func__, err.what());
+        rules.clear();
+        return false;
+    }
+
+    return true;
+}
+
+void llama_grammar_parser::print(FILE * file) {
+    try {
+        std::map<uint32_t, std::string> symbol_id_names;
+        for (const auto & kv : symbol_ids) {
+            symbol_id_names[kv.second] = kv.first;
+        }
+        for (size_t i = 0, end = rules.size(); i < end; i++) {
+            // fprintf(file, "%zu: ", i);
+            // print_rule_binary(file, rules[i]);
+            print_rule(file, uint32_t(i), rules[i], symbol_id_names);
+            // fprintf(file, "\n");
+        }
+    } catch (const std::exception & err) {
+        fprintf(stderr, "\n%s: error printing grammar: %s\n", __func__, err.what());
+    }
+}
+
+llama_grammar_stack llama_grammar_parser::c_rules() const {
+    llama_grammar_stack ret;
+    ret.reserve(rules.size());
+    for (const auto & rule : rules) {
+        ret.push_back(rule.data());
+    }
+    return ret;
 }
 
 // returns true iff pos points to the end of one of the definitions of a rule
@@ -89,7 +607,6 @@ static bool llama_grammar_is_end_of_sequence(const llama_grammar_element * pos)
 static std::pair<bool, const llama_grammar_element *> llama_grammar_match_char(
         const llama_grammar_element * pos,
         const uint32_t                chr) {
-
     bool found            = false;
     bool is_positive_char = pos->type == LLAMA_GRETYPE_CHAR || pos->type == LLAMA_GRETYPE_CHAR_ANY;
 
@@ -225,16 +742,93 @@ static void llama_grammar_advance_stack(
     }
 }
 
-// takes a set of possible pushdown stacks on a grammar, which are required to
-// be positioned at a character range (see `llama_grammar_advance_stack`), and
-// produces the N possible stacks if the given char is accepted at those
-// positions
+static llama_grammar_candidates llama_grammar_reject_candidates(
+        const llama_grammar_rules      & rules,
+        const llama_grammar_stacks     & stacks,
+        const llama_grammar_candidates & candidates) {
+    GGML_ASSERT(!stacks.empty()); // REVIEW
+
+    if (candidates.empty()) {
+        return {};
+    }
+
+    auto rejects = llama_grammar_reject_candidates_for_stack(rules, stacks.front(), candidates);
+
+    for (size_t i = 1, size = stacks.size(); i < size; ++i) {
+        rejects = llama_grammar_reject_candidates_for_stack(rules, stacks[i], rejects);
+    }
+
+    return rejects;
+}
+
+static bool llama_grammar_detect_left_recursion(
+        const llama_grammar_rules & rules,
+        size_t rule_index,
+        std::vector<bool> * rules_visited,
+        std::vector<bool> * rules_in_progress,
+        std::vector<bool> * rules_may_be_empty) {
+    if ((*rules_in_progress)[rule_index]) {
+        return true;
+    }
+
+    (*rules_in_progress)[rule_index] = true;
+
+    const llama_grammar_rule & rule = rules[rule_index];
+
+    // First check if the rule might produce the empty string. This could be done combined with the second
+    // step but it's more readable as two steps.
+    bool at_rule_start = true;
+    for (size_t i = 0; i < rule.size(); i++) {
+        if (llama_grammar_is_end_of_sequence(&rule[i])) {
+            if (at_rule_start) {
+                (*rules_may_be_empty)[rule_index] = true;
+                break;
+            }
+            at_rule_start = true;
+        } else {
+            at_rule_start = false;
+        }
+    }
+
+    // Second, recurse into leftmost nonterminals (or next-leftmost as long as the previous nonterminal may
+    // be empty)
+    bool recurse_into_nonterminal = true;
+    for (size_t i = 0; i < rule.size(); i++) {
+        if (rule[i].type == LLAMA_GRETYPE_RULE_REF && recurse_into_nonterminal) {
+            if (llama_grammar_detect_left_recursion(rules, (size_t)rule[i].value, rules_visited, rules_in_progress, rules_may_be_empty)) {
+                return true;
+            }
+            if (!((*rules_may_be_empty)[(size_t)rule[i].value])) {
+                recurse_into_nonterminal = false;
+            }
+        } else if (llama_grammar_is_end_of_sequence(&rule[i])) {
+            recurse_into_nonterminal = true;
+        } else {
+            recurse_into_nonterminal = false;
+        }
+    }
+
+    (*rules_in_progress)[rule_index] = false;
+    (*rules_visited)[rule_index] = true;
+
+    return false;
+}
+
+const llama_grammar_rules & llama_grammar_get_rules(const struct llama_grammar * grammar) {
+    return grammar->rules;
+}
+
+llama_grammar_stacks & llama_grammar_get_stacks(struct llama_grammar * grammar) {
+    return grammar->stacks;
+}
+
 void llama_grammar_accept(
         const llama_grammar_rules  & rules,
         const llama_grammar_stacks & stacks,
         const uint32_t               chr,
-              llama_grammar_stacks & new_stacks) {
-    new_stacks.clear();
+              llama_grammar_stacks & stacks_new) {
+    stacks_new.clear();
+    stacks_new.reserve(stacks.size());
 
     for (const auto & stack : stacks) {
         if (stack.empty()) {
@@ -250,29 +844,11 @@ void llama_grammar_accept(
             if (!llama_grammar_is_end_of_sequence(pos)) {
                 new_stack.push_back(pos);
             }
-            llama_grammar_advance_stack(rules, new_stack, new_stacks);
+            llama_grammar_advance_stack(rules, new_stack, stacks_new);
         }
     }
 }
 
-static llama_grammar_candidates llama_grammar_reject_candidates(
-        const llama_grammar_rules  & rules,
-        const llama_grammar_stacks & stacks,
-        const llama_grammar_candidates & candidates) {
-    GGML_ASSERT(!stacks.empty()); // REVIEW
-
-    if (candidates.empty()) {
-        return {};
-    }
-
-    auto rejects = llama_grammar_reject_candidates_for_stack(rules, stacks.front(), candidates);
-
-    for (size_t i = 1, size = stacks.size(); i < size; ++i) {
-        rejects = llama_grammar_reject_candidates_for_stack(rules, stacks[i], rejects);
-    }
-    return rejects;
-}
-
 llama_grammar_candidates llama_grammar_reject_candidates_for_stack(
         const llama_grammar_rules      & rules,
         const llama_grammar_stack      & stack,
@@ -328,72 +904,97 @@ llama_grammar_candidates llama_grammar_reject_candidates_for_stack(
     return rejects;
 }
 
-static bool llama_grammar_detect_left_recursion(
-        const llama_grammar_rules & rules,
-        size_t rule_index,
-        std::vector<bool> * rules_visited,
-        std::vector<bool> * rules_in_progress,
-        std::vector<bool> * rules_may_be_empty) {
-    if ((*rules_in_progress)[rule_index]) {
-        return true;
-    }
+////////////////////
 
-    (*rules_in_progress)[rule_index] = true;
+struct llama_grammar * llama_grammar_init_impl(
+        const struct llama_vocab * vocab,
+        const llama_grammar_element ** rules,
+        size_t n_rules,
+        size_t start_rule_index) {
+    const llama_grammar_element * pos;
 
-    const llama_grammar_rule & rule = rules[rule_index];
+    // copy rule definitions into vectors
+    llama_grammar_rules vec_rules(n_rules);
+    for (size_t i = 0; i < n_rules; i++) {
+        for (pos = rules[i]; pos->type != LLAMA_GRETYPE_END; pos++) {
+            vec_rules[i].push_back(*pos);
+        }
+        vec_rules[i].push_back({LLAMA_GRETYPE_END, 0});
+    }
 
-    // First check if the rule might produce the empty string. This could be done combined with the second
-    // step but it's more readable as two steps.
-    bool at_rule_start = true;
-    for (size_t i = 0; i < rule.size(); i++) {
-        if (llama_grammar_is_end_of_sequence(&rule[i])) {
-            if (at_rule_start) {
-                (*rules_may_be_empty)[rule_index] = true;
-                break;
-            }
-            at_rule_start = true;
-        } else {
-            at_rule_start = false;
+    // Check for left recursion
+    std::vector<bool> rules_visited(n_rules);
+    std::vector<bool> rules_in_progress(n_rules);
+    std::vector<bool> rules_may_be_empty(n_rules);
+    for (size_t i = 0; i < n_rules; i++) {
+        if (rules_visited[i]) {
+            continue;
+        }
+        if (llama_grammar_detect_left_recursion(vec_rules, i, &rules_visited, &rules_in_progress, &rules_may_be_empty)) {
+            LLAMA_LOG_ERROR("unsupported grammar, left recursion detected for nonterminal at index %zu", i);
+            return nullptr;
         }
     }
 
-    // Second, recurse into leftmost nonterminals (or next-leftmost as long as the previous nonterminal may
-    // be empty)
-    bool recurse_into_nonterminal = true;
-    for (size_t i = 0; i < rule.size(); i++) {
-        if (rule[i].type == LLAMA_GRETYPE_RULE_REF && recurse_into_nonterminal) {
-            if (llama_grammar_detect_left_recursion(rules, (size_t)rule[i].value, rules_visited, rules_in_progress, rules_may_be_empty)) {
-                return true;
-            }
-            if (!((*rules_may_be_empty)[(size_t)rule[i].value])) {
-                recurse_into_nonterminal = false;
-            }
-        } else if (llama_grammar_is_end_of_sequence(&rule[i])) {
-            recurse_into_nonterminal = true;
+    // loop over alternates of start rule to build initial stacks
+    llama_grammar_stacks stacks;
+    pos = vec_rules[start_rule_index].data();
+    do {
+        llama_grammar_stack stack;
+        if (!llama_grammar_is_end_of_sequence(pos)) {
+            // if alternate is nonempty, add to stack
+            stack.push_back(pos);
+        }
+        llama_grammar_advance_stack(vec_rules, stack, stacks);
+        while (!llama_grammar_is_end_of_sequence(pos)) {
+            // scan to end of alternate def
+            pos++;
+        }
+        if (pos->type == LLAMA_GRETYPE_ALT) {
+            // there's another alternate def of this rule to process
+            pos++;
         } else {
-            recurse_into_nonterminal = false;
+            break;
         }
-    }
+    } while (true);
 
-    (*rules_in_progress)[rule_index] = false;
-    (*rules_visited)[rule_index] = true;
-    return false;
+    // Important: vec_rules has to be moved here, not copied, because stacks contains
+    // pointers to elements of vec_rules. If vec_rules were copied into llama_grammar
+    // then the pointers would be invalidated when the local vec_rules goes out of scope.
+    return new llama_grammar { vocab, std::move(vec_rules), std::move(stacks), {}, };
 }
 
-//
-// grammar - external
-//
+struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab * vocab, const char * grammar_str, const char * grammar_root) {
+    llama_grammar_parser parser;
+
+    // if there is a grammar, parse it
+    if (!parser.parse(grammar_str)) {
+        return nullptr;
+    }
+
+    // will be empty (default) if there are parse errors
+    if (parser.rules.empty()) {
+        fprintf(stderr, "%s: failed to parse grammar\n", __func__);
+        return nullptr;
+    }
+
+    // Ensure that there is a "root" node.
+    if (parser.symbol_ids.find("root") == parser.symbol_ids.end()) {
+        fprintf(stderr, "%s: grammar does not contain a 'root' symbol\n", __func__);
+        return nullptr;
+    }
+
+    std::vector<const llama_grammar_element *> grammar_rules(parser.c_rules());
+
+    const size_t n_rules = grammar_rules.size();
+    const size_t start_rule_index = parser.symbol_ids.at(grammar_root);
 
-struct llama_grammar * llama_grammar_init_impl(
-            const llama_grammar_element ** rules,
-                                 size_t    n_rules,
-                                 size_t    start_rule_index) {
     const llama_grammar_element * pos;
 
     // copy rule definitions into vectors
     llama_grammar_rules vec_rules(n_rules);
     for (size_t i = 0; i < n_rules; i++) {
-        for (pos = rules[i]; pos->type != LLAMA_GRETYPE_END; pos++) {
+        for (pos = grammar_rules[i]; pos->type != LLAMA_GRETYPE_END; pos++) {
             vec_rules[i].push_back(*pos);
         }
         vec_rules[i].push_back({LLAMA_GRETYPE_END, 0});
@@ -438,22 +1039,26 @@ struct llama_grammar * llama_grammar_init_impl(
     // Important: vec_rules has to be moved here, not copied, because stacks contains
     // pointers to elements of vec_rules. If vec_rules were copied into llama_grammar
     // then the pointers would be invalidated when the local vec_rules goes out of scope.
-    return new llama_grammar{ std::move(vec_rules), std::move(stacks), {} };
+    return new llama_grammar { vocab, std::move(vec_rules), std::move(stacks), {}, };
 }
 
 void llama_grammar_free_impl(struct llama_grammar * grammar) {
+    if (grammar == nullptr) {
+        return;
+    }
+
     delete grammar;
 }
 
-struct llama_grammar * llama_grammar_copy_impl(const struct llama_grammar * grammar) {
-    llama_grammar * result = new llama_grammar{ grammar->rules, grammar->stacks, grammar->partial_utf8 };
+struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & grammar) {
+    llama_grammar * result = new llama_grammar { grammar.vocab, grammar.rules, grammar.stacks, grammar.partial_utf8, };
 
     // redirect elements in stacks to point to new rules
     for (size_t is = 0; is < result->stacks.size(); is++) {
         for (size_t ie = 0; ie < result->stacks[is].size(); ie++) {
-            for (size_t ir0 = 0; ir0 < grammar->rules.size(); ir0++) {
-                for (size_t ir1 = 0; ir1 < grammar->rules[ir0].size(); ir1++) {
-                    if (grammar->stacks[is][ie] == &grammar->rules[ir0][ir1]) {
+            for (size_t ir0 = 0; ir0 < grammar.rules.size(); ir0++) {
+                for (size_t ir1 = 0; ir1 < grammar.rules[ir0].size(); ir1++) {
+                    if (grammar.stacks[is][ie] == &grammar.rules[ir0][ir1]) {
                          result->stacks[is][ie]  =  &result->rules[ir0][ir1];
                     }
                 }
@@ -464,14 +1069,11 @@ struct llama_grammar * llama_grammar_copy_impl(const struct llama_grammar * gram
     return result;
 }
 
-void llama_grammar_sample_impl(const struct llama_grammar * grammar, const struct llama_vocab * vocab, const struct llama_sampling * smpl, llama_token_data_array * candidates) {
-    GGML_ASSERT(grammar);
-    GGML_ASSERT(vocab);
-
-    int64_t t_start_sample_us = ggml_time_us();
+void llama_grammar_apply_impl(const struct llama_grammar & grammar, llama_token_data_array * cur_p) {
+    GGML_ASSERT(grammar.vocab != nullptr);
 
     bool allow_eog = false;
-    for (const auto & stack : grammar->stacks) {
+    for (const auto & stack : grammar.stacks) {
         if (stack.empty()) {
             allow_eog = true;
             break;
@@ -479,40 +1081,38 @@ void llama_grammar_sample_impl(const struct llama_grammar * grammar, const struc
     }
 
     std::vector<std::pair<std::vector<uint32_t>, llama_partial_utf8>> candidates_decoded;
-    candidates_decoded.reserve(candidates->size);
+    candidates_decoded.reserve(cur_p->size);
 
     llama_grammar_candidates candidates_grammar;
-    candidates_grammar.reserve(candidates->size);
+    candidates_grammar.reserve(cur_p->size);
 
-    for (size_t i = 0; i < candidates->size; ++i) {
-        const llama_token id      = candidates->data[i].id;
-        const std::string & piece = vocab->cache_token_to_piece.at(id);
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        const llama_token id      = cur_p->data[i].id;
+        const std::string & piece = grammar.vocab->cache_token_to_piece.at(id);
 
-        if (llama_token_is_eog_impl(*vocab, id)) {
+        if (llama_token_is_eog_impl(*grammar.vocab, id)) {
             if (!allow_eog) {
-                candidates->data[i].logit = -INFINITY;
+                cur_p->data[i].logit = -INFINITY;
             }
         } else if (piece.empty() || piece[0] == 0) {
-            candidates->data[i].logit = -INFINITY;
+            cur_p->data[i].logit = -INFINITY;
         } else {
-            candidates_decoded.push_back(decode_utf8(piece, grammar->partial_utf8));
+            candidates_decoded.push_back(decode_utf8(piece, grammar.partial_utf8));
             candidates_grammar.push_back({ i, candidates_decoded.back().first.data(), candidates_decoded.back().second });
         }
     }
 
-    const auto rejects = llama_grammar_reject_candidates(grammar->rules, grammar->stacks, candidates_grammar);
+    const auto rejects = llama_grammar_reject_candidates(grammar.rules, grammar.stacks, candidates_grammar);
     for (const auto & reject : rejects) {
-        candidates->data[reject.index].logit = -INFINITY;
+        cur_p->data[reject.index].logit = -INFINITY;
     }
-
-    smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
 }
 
-void llama_grammar_accept_token_impl(struct llama_grammar * grammar, const struct llama_vocab * vocab, const struct llama_sampling * smpl, llama_token token) {
-    const int64_t t_start_sample_us = ggml_time_us();
+void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token) {
+    GGML_ASSERT(grammar.vocab != nullptr);
 
-    if (llama_token_is_eog_impl(*vocab, token)) {
-        for (const auto & stack : grammar->stacks) {
+    if (llama_token_is_eog_impl(*grammar.vocab, token)) {
+        for (const auto & stack : grammar.stacks) {
             if (stack.empty()) {
                 return;
             }
@@ -520,20 +1120,19 @@ void llama_grammar_accept_token_impl(struct llama_grammar * grammar, const struc
         GGML_ABORT("fatal error");
     }
 
-    const std::string & piece = vocab->cache_token_to_piece.at(token);
+    const std::string & piece = grammar.vocab->cache_token_to_piece.at(token);
 
     // Note terminating 0 in decoded string
-    const auto   decoded     = decode_utf8(piece, grammar->partial_utf8);
+    const auto   decoded     = decode_utf8(piece, grammar.partial_utf8);
     const auto & code_points = decoded.first;
 
-    llama_grammar_stacks tmp_new_stacks;
+    llama_grammar_stacks stacks_new;
+
     for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
-        llama_grammar_accept(grammar->rules, grammar->stacks, *it, tmp_new_stacks);
-        grammar->stacks = tmp_new_stacks;
+        llama_grammar_accept(grammar.rules, grammar.stacks, *it, stacks_new);
+        grammar.stacks = std::move(stacks_new);
     }
 
-    grammar->partial_utf8 = decoded.second;
-    GGML_ASSERT(!grammar->stacks.empty());
-
-    smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
+    grammar.partial_utf8 = decoded.second;
+    GGML_ASSERT(!grammar.stacks.empty());
 }
index 695ea0632bb84c698db84e833a615d32199821c1..f529ce351e4167d03cbeb538018047a6287c1c02 100644 (file)
 
 #include "llama-impl.h"
 
+#include <map>
+
 struct llama_vocab;
-struct llama_sampling;
+
+// grammar element type
+enum llama_gretype {
+    // end of rule definition
+    LLAMA_GRETYPE_END            = 0,
+
+    // start of alternate definition for rule
+    LLAMA_GRETYPE_ALT            = 1,
+
+    // non-terminal element: reference to rule
+    LLAMA_GRETYPE_RULE_REF       = 2,
+
+    // terminal element: character (code point)
+    LLAMA_GRETYPE_CHAR           = 3,
+
+    // inverse char(s) ([^a], [^a-b] [^abc])
+    LLAMA_GRETYPE_CHAR_NOT       = 4,
+
+    // modifies a preceding LLAMA_GRETYPE_CHAR or LLAMA_GRETYPE_CHAR_ALT to
+    // be an inclusive range ([a-z])
+    LLAMA_GRETYPE_CHAR_RNG_UPPER = 5,
+
+    // modifies a preceding LLAMA_GRETYPE_CHAR or
+    // LLAMA_GRETYPE_CHAR_RNG_UPPER to add an alternate char to match ([ab], [a-zA])
+    LLAMA_GRETYPE_CHAR_ALT       = 6,
+
+    // any character (.)
+    LLAMA_GRETYPE_CHAR_ANY       = 7,
+};
+
+typedef struct llama_grammar_element {
+    enum llama_gretype type;
+    uint32_t           value; // Unicode code point or rule ID
+} llama_grammar_element;
+
+struct llama_partial_utf8 {
+    uint32_t value;    // bit value so far (unshifted)
+    int      n_remain; // num bytes remaining; -1 indicates invalid sequence
+};
+
+struct llama_grammar_candidate {
+    size_t               index;
+    const uint32_t     * code_points;
+    llama_partial_utf8   partial_utf8;
+};
+
+using llama_grammar_rule  = std::vector<      llama_grammar_element>;
+using llama_grammar_stack = std::vector<const llama_grammar_element *>;
+
+using llama_grammar_rules      = std::vector<llama_grammar_rule>;
+using llama_grammar_stacks     = std::vector<llama_grammar_stack>;
+using llama_grammar_candidates = std::vector<llama_grammar_candidate>;
+
+const llama_grammar_rules  & llama_grammar_get_rules (const struct llama_grammar * grammar);
+      llama_grammar_stacks & llama_grammar_get_stacks(      struct llama_grammar * grammar);
+
+// takes a set of possible pushdown stacks on a grammar, which are required to
+// be positioned at a character range (see `llama_grammar_advance_stack`), and
+// produces the N possible stacks if the given char is accepted at those
+// positions
+void llama_grammar_accept(
+        const llama_grammar_rules  & rules,
+        const llama_grammar_stacks & stacks,
+                          uint32_t   chr,
+              llama_grammar_stacks & stacks_new);
+
+std::vector<llama_grammar_candidate> llama_grammar_reject_candidates_for_stack(
+        const llama_grammar_rules      & rules,
+        const llama_grammar_stack      & stack,
+        const llama_grammar_candidates & candidates);
+
+struct llama_grammar_parser {
+    std::map<std::string, uint32_t> symbol_ids;
+
+    llama_grammar_rules rules;
+
+    llama_grammar_stack c_rules() const;
+
+    uint32_t get_symbol_id(const char * src, size_t len);
+    uint32_t generate_symbol_id(const std::string & base_name);
+
+    void add_rule(uint32_t rule_id, const llama_grammar_rule & rule);
+
+    const char * parse_alternates(
+            const char        * src,
+            const std::string & rule_name,
+            uint32_t            rule_id,
+            bool                is_nested);
+
+    const char * parse_sequence(
+            const char         * src,
+            const std::string  & rule_name,
+            llama_grammar_rule & rule,
+            bool               is_nested);
+
+    const char * parse_rule(const char * src);
+
+    bool parse(const char * src);
+    void print(FILE * file);
+};
 
 struct llama_grammar {
-    const llama_grammar_rules  rules;
+    // note: allow null vocab for testing (not great)
+    const llama_vocab * vocab;
+
+    const llama_grammar_rules  rules;  // TODO: shared ptr
           llama_grammar_stacks stacks;
 
     // buffer for partially generated UTF-8 sequence from accepted tokens
@@ -17,23 +121,24 @@ struct llama_grammar {
 // internal API
 //
 
+// note: needed for tests (not great)
 struct llama_grammar * llama_grammar_init_impl(
-            const llama_grammar_element ** rules,
-                                 size_t    n_rules,
-                                 size_t    start_rule_index);
+        const struct llama_vocab * vocab,
+        const llama_grammar_element ** rules,
+        size_t n_rules,
+        size_t start_rule_index);
+
+struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab * vocab, const char * grammar_str, const char * grammar_root);
 
 void llama_grammar_free_impl(struct llama_grammar * grammar);
 
-struct llama_grammar * llama_grammar_copy_impl(const struct llama_grammar * grammar);
+struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & grammar);
 
-void llama_grammar_sample_impl(
-        const struct llama_grammar * grammar,
-          const struct llama_vocab * vocab,
-       const struct llama_sampling * smpl,
-            llama_token_data_array * candidates);
+// TODO: move the API below as member functions of llama_grammar
+void llama_grammar_apply_impl(
+        const struct llama_grammar & grammar,
+            llama_token_data_array * cur_p);
 
-void llama_grammar_accept_token_impl(
-              struct llama_grammar * grammar,
-          const struct llama_vocab * vocab,
-       const struct llama_sampling * smpl,
+void llama_grammar_accept_impl(
+              struct llama_grammar & grammar,
                        llama_token   token);
index 9527740961da652657d15b898d973ad8aac6d33c..fa2e09e1f688e6b5c632af9baa9d7bef800738cc 100644 (file)
@@ -1,8 +1,11 @@
 #pragma once
 
-#define LLAMA_API_INTERNAL
 #include "llama.h"
 
+#include <string>
+#include <vector>
+#include <stdexcept>
+
 #ifdef __GNUC__
 #ifdef __MINGW32__
 #define LLAMA_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__)))
@@ -29,6 +32,20 @@ void llama_log_callback_default(ggml_log_level level, const char * text, void *
 // helpers
 //
 
+struct time_meas {
+    time_meas(int64_t & t_acc, bool disable = false) : t_start_us(disable ? -1 : ggml_time_us()), t_acc(t_acc) {}
+
+    ~time_meas() {
+        if (t_start_us >= 0) {
+            t_acc += ggml_time_us() - t_start_us;
+        }
+    }
+
+    const int64_t t_start_us;
+
+    int64_t & t_acc;
+};
+
 static void replace_all(std::string & s, const std::string & search, const std::string & replace) {
     if (search.empty()) {
         return;
@@ -45,3 +62,113 @@ static void replace_all(std::string & s, const std::string & search, const std::
     builder.append(s, last_pos, std::string::npos);
     s = std::move(builder);
 }
+
+const std::vector<std::pair<std::string, struct ggml_tensor *>> & llama_internal_get_tensor_map(
+    struct llama_context * ctx
+);
+
+// the ring buffer works similarly to std::deque, but with a fixed capacity
+template<typename T>
+struct ring_buffer {
+    ring_buffer(size_t cap) : capacity(cap), data(cap) {}
+
+    T & front() {
+        if (sz == 0) {
+            throw std::runtime_error("ring buffer is empty");
+        }
+        return data[first];
+    }
+
+    const T & front() const {
+        if (sz == 0) {
+            throw std::runtime_error("ring buffer is empty");
+        }
+        return data[first];
+    }
+
+    T & back() {
+        if (sz == 0) {
+            throw std::runtime_error("ring buffer is empty");
+        }
+        return data[pos];
+    }
+
+    const T & back() const {
+        if (sz == 0) {
+            throw std::runtime_error("ring buffer is empty");
+        }
+        return data[pos];
+    }
+
+    void push_back(const T & value) {
+        if (sz == capacity) {
+            // advance the start when buffer is full
+            first = (first + 1) % capacity;
+        } else {
+            sz++;
+        }
+        data[pos] = value;
+        pos = (pos + 1) % capacity;
+    }
+
+    T pop_front() {
+        if (sz == 0) {
+            throw std::runtime_error("ring buffer is empty");
+        }
+        T value = data[first];
+        first = (first + 1) % capacity;
+        sz--;
+        return value;
+    }
+
+    //T & operator[](size_t i) {
+    //    if (i >= sz) {
+    //        throw std::runtime_error("ring buffer: index out of bounds");
+    //    }
+    //    return data[(first + i) % capacity];
+    //}
+
+    //const T & at(size_t i) const {
+    //    if (i >= sz) {
+    //        throw std::runtime_error("ring buffer: index out of bounds");
+    //    }
+    //    return data[(first + i) % capacity];
+    //}
+
+    const T & rat(size_t i) const {
+        if (i >= sz) {
+            throw std::runtime_error("ring buffer: index out of bounds");
+        }
+        return data[(first + sz - i - 1) % capacity];
+    }
+
+    std::vector<T> to_vector() const {
+        std::vector<T> result;
+        result.reserve(sz);
+        for (size_t i = 0; i < sz; i++) {
+            result.push_back(data[(first + i) % capacity]);
+        }
+        return result;
+    }
+
+    void clear() {
+        // here only reset the status of the buffer
+        sz = 0;
+        first = 0;
+        pos = 0;
+    }
+
+    bool empty() const {
+        return sz == 0;
+    }
+
+    size_t size() const {
+        return sz;
+    }
+
+    size_t capacity = 0;
+    size_t sz = 0;
+    size_t first = 0;
+    size_t pos = 0;
+    std::vector<T> data;
+};
index 8f4841d9daf7b90f681eca1d0c989e761d205181..61f4cbb9217e88a04a2c6d09f89b59222ddbcae0 100644 (file)
@@ -1,12 +1,28 @@
 #include "llama-sampling.h"
 
+#include "llama-vocab.h"
+#include "llama-grammar.h"
+
+#include <cassert>
 #include <algorithm>
 #include <cstring>
 #include <ctime>
 #include <cfloat>
 #include <numeric>
+#include <random>
 #include <unordered_map>
 
+static int llama_sample_dist(llama_token_data_array * cur_p, std::mt19937 & rng, std::vector<float> & probs) {
+    probs.resize(cur_p->size);
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        probs[i] = cur_p->data[i].p;
+    }
+
+    std::discrete_distribution<size_t> dist(probs.begin(), probs.end());
+
+    return dist(rng);
+}
+
 static void llama_log_softmax(float * array, size_t size) {
     float max_l = *std::max_element(array, array + size);
     float sum = 0.f;
@@ -21,65 +37,50 @@ static void llama_log_softmax(float * array, size_t size) {
     }
 }
 
-void llama_set_rng_seed_impl(struct llama_sampling * smpl, uint32_t seed) {
-    if (seed == LLAMA_DEFAULT_SEED) {
-        seed = time(NULL);
-    }
-
-    smpl->rng.seed(seed);
-}
-
-void llama_sample_softmax_impl(struct llama_sampling * smpl, llama_token_data_array * candidates) {
-    GGML_ASSERT(candidates->size > 0);
-
-    const int64_t t_start_sample_us = ggml_time_us();
+static void llama_sampler_softmax_impl(llama_token_data_array * cur_p) {
+    GGML_ASSERT(cur_p->size > 0);
 
     // Sort the logits in descending order
-    if (!candidates->sorted) {
-        std::sort(candidates->data, candidates->data + candidates->size, [](const llama_token_data & a, const llama_token_data & b) {
+    if (!cur_p->sorted) {
+        std::sort(cur_p->data, cur_p->data + cur_p->size, [](const llama_token_data & a, const llama_token_data & b) {
             return a.logit > b.logit;
         });
-        candidates->sorted = true;
+        cur_p->sorted = true;
     }
 
-    float max_l = candidates->data[0].logit;
+    float max_l = cur_p->data[0].logit;
     float cum_sum = 0.0f;
-    for (size_t i = 0; i < candidates->size; ++i) {
-        float p = expf(candidates->data[i].logit - max_l);
-        candidates->data[i].p = p;
+
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        float p = expf(cur_p->data[i].logit - max_l);
+        cur_p->data[i].p = p;
         cum_sum += p;
     }
-    for (size_t i = 0; i < candidates->size; ++i) {
-        candidates->data[i].p /= cum_sum;
-    }
 
-    if (smpl) {
-        smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        cur_p->data[i].p /= cum_sum;
     }
 }
 
-void llama_sample_top_k_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, int32_t k, size_t min_keep) {
+static void llama_sampler_top_k_impl(llama_token_data_array * cur_p, int32_t k) {
     // TODO: move bucket sort to separate function so that top_p/tail_free/typical/softmax first is equally fast
-    // if (k >= (int32_t)candidates->size) {
+    // if (k >= (int32_t)cur_p->size) {
     //     return;
     // }
 
-    const int64_t t_start_sample_us = ggml_time_us();
-
     if (k <= 0) {
-        k = candidates->size;
+        k = cur_p->size;
     }
 
-    k = std::max(k, (int) min_keep);
-    k = std::min(k, (int) candidates->size);
+    k = std::min(k, (int) cur_p->size);
 
     // Sort scores in descending order
-    if (!candidates->sorted) {
+    if (!cur_p->sorted) {
         auto comp = [](const llama_token_data & a, const llama_token_data & b) {
             return a.logit > b.logit;
         };
         if (k <= 128) {
-            std::partial_sort(candidates->data, candidates->data + k, candidates->data + candidates->size, comp);
+            std::partial_sort(cur_p->data, cur_p->data + k, cur_p->data + cur_p->size, comp);
         } else {
             constexpr int   nbuckets     = 128;
             constexpr float bucket_low   = -10.0f;
@@ -87,11 +88,11 @@ void llama_sample_top_k_impl(struct llama_sampling * smpl, llama_token_data_arra
             constexpr float bucket_scale = nbuckets/(bucket_high - bucket_low);
             constexpr float bucket_inter = -bucket_low * bucket_scale;
 
-            std::vector<int> bucket_idx(candidates->size);
+            std::vector<int> bucket_idx(cur_p->size);
             std::vector<int> histo(nbuckets, 0);
 
-            for (int i = 0; i < (int)candidates->size; ++i) {
-                const float val = candidates->data[i].logit;
+            for (int i = 0; i < (int)cur_p->size; ++i) {
+                const float val = cur_p->data[i].logit;
                 int ib = int(bucket_scale * val + bucket_inter); //nbuckets * (val - bucket_low) / (bucket_high - bucket_low);
                 ib = std::max(0, std::min(nbuckets-1, ib));
                 bucket_idx[i] = ib;
@@ -101,20 +102,22 @@ void llama_sample_top_k_impl(struct llama_sampling * smpl, llama_token_data_arra
             int ib = nbuckets - 1;
             for ( ; ib >= 0; --ib) {
                 nhave += histo[ib];
-                if (nhave >= k) break;
+                if (nhave >= k) {
+                    break;
+                }
             }
             std::vector<llama_token_data> tmp_tokens(nhave);
-            auto ptr = tmp_tokens.data();
+            auto ptr = tmp_tokens.data();
             std::vector<llama_token_data*> bucket_ptrs;
             bucket_ptrs.reserve(nbuckets - ib);
             for (int j = nbuckets - 1; j >= ib; --j) {
                 bucket_ptrs.push_back(ptr);
                 ptr += histo[j];
             }
-            for (int i = 0; i < (int)candidates->size; ++i) {
+            for (int i = 0; i < (int)cur_p->size; ++i) {
                 int j = bucket_idx[i];
                 if (j >= ib) {
-                    *bucket_ptrs[nbuckets-1-j]++ = candidates->data[i];
+                    *bucket_ptrs[nbuckets-1-j]++ = cur_p->data[i];
                 }
             }
 
@@ -127,33 +130,27 @@ void llama_sample_top_k_impl(struct llama_sampling * smpl, llama_token_data_arra
             }
             std::partial_sort(ptr, ptr + k - ndone, ptr + histo[ib], comp);
 
-            std::memcpy(candidates->data, tmp_tokens.data(), k*sizeof(llama_token_data));
+            std::memcpy(cur_p->data, tmp_tokens.data(), k*sizeof(llama_token_data));
 
         }
-        candidates->sorted = true;
-    }
-    candidates->size = k;
-
-    if (smpl) {
-        smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
+        cur_p->sorted = true;
     }
+    cur_p->size = k;
 }
 
-void llama_sample_top_p_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep) {
+static void llama_sampler_top_p_impl(llama_token_data_array * cur_p, float p, size_t min_keep) {
     if (p >= 1.0f) {
         return;
     }
 
-    llama_sample_softmax_impl(smpl, candidates);
-
-    const int64_t t_start_sample_us = ggml_time_us();
+    llama_sampler_softmax_impl(cur_p);
 
     // Compute the cumulative probabilities
     float cum_sum = 0.0f;
-    size_t last_idx = candidates->size;
+    size_t last_idx = cur_p->size;
 
-    for (size_t i = 0; i < candidates->size; ++i) {
-        cum_sum += candidates->data[i].p;
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        cum_sum += cur_p->data[i].p;
 
         // Check if the running sum is at least p or if we have kept at least min_keep tokens
         // we set the last index to i+1 to indicate that the current iterate should be included in the set
@@ -164,88 +161,77 @@ void llama_sample_top_p_impl(struct llama_sampling * smpl, llama_token_data_arra
     }
 
     // Resize the output vector to keep only the top-p tokens
-    candidates->size = last_idx;
-
-    if (smpl) {
-        smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
-    }
+    cur_p->size = last_idx;
 }
 
-void llama_sample_min_p_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep) {
-    if (p <= 0.0f || !candidates->size) {
+static void llama_sampler_min_p_impl(llama_token_data_array * cur_p, float p, size_t min_keep) {
+    if (p <= 0.0f || !cur_p->size) {
         return;
     }
 
-    const int64_t t_start_sample_us = ggml_time_us();
-
     bool min_p_applied = false;
 
-    // if the candidates aren't sorted, try the unsorted implementation first
-    if (!candidates->sorted) {
+    // if the cur_p aren't sorted, try the unsorted implementation first
+    if (!cur_p->sorted) {
         std::vector<llama_token_data> filtered_tokens;
 
         float max_logit = -FLT_MAX;
-        for (size_t i = 0; i < candidates->size; ++i) {
-            max_logit = std::max(max_logit, candidates->data[i].logit);
+        for (size_t i = 0; i < cur_p->size; ++i) {
+            max_logit = std::max(max_logit, cur_p->data[i].logit);
         }
         const float min_logit = max_logit + logf(p); // min logit for p_i >= p * p_max
 
-        for (size_t i = 0; i < candidates->size; ++i) {
-            if (candidates->data[i].logit >= min_logit) {
-                filtered_tokens.push_back(candidates->data[i]);
+        for (size_t i = 0; i < cur_p->size; ++i) {
+            if (cur_p->data[i].logit >= min_logit) {
+                filtered_tokens.push_back(cur_p->data[i]);
             }
         }
 
         // if we have enough values the operation was a success
         if (filtered_tokens.size() >= min_keep) {
-            memcpy(candidates->data, filtered_tokens.data(), filtered_tokens.size()*sizeof(llama_token_data));
-            candidates->size = filtered_tokens.size();
+            memcpy(cur_p->data, filtered_tokens.data(), filtered_tokens.size()*sizeof(llama_token_data));
+            cur_p->size = filtered_tokens.size();
             min_p_applied = true;
         }
     }
 
-    // if the candidates are sorted or the unsorted implementation failed, use this implementation
+    // if the cur_p are sorted or the unsorted implementation failed, use this implementation
     if (!min_p_applied) {
         // Sort the logits in descending order
-        if (!candidates->sorted) {
-            std::sort(candidates->data, candidates->data + candidates->size, [](const llama_token_data & a, const llama_token_data & b) {
+        if (!cur_p->sorted) {
+            std::sort(cur_p->data, cur_p->data + cur_p->size, [](const llama_token_data & a, const llama_token_data & b) {
                 return a.logit > b.logit;
             });
-            candidates->sorted = true;
+            cur_p->sorted = true;
         }
 
-        const float min_logit = candidates->data[0].logit + logf(p); // min logit for p_i >= p * p_max
+        const float min_logit = cur_p->data[0].logit + logf(p); // min logit for p_i >= p * p_max
         size_t i = 1; // first token always matches
 
-        for (; i < candidates->size; ++i) {
-            if (candidates->data[i].logit < min_logit && i >= min_keep) {
+        for (; i < cur_p->size; ++i) {
+            if (cur_p->data[i].logit < min_logit && i >= min_keep) {
                 break; // prob too small
             }
         }
 
         // Resize the output vector to keep only the matching tokens
-        candidates->size = i;
-    }
-
-    if (smpl) {
-        smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
+        cur_p->size = i;
     }
 }
 
-void llama_sample_tail_free_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float z, size_t min_keep) {
-    if (z >= 1.0f || candidates->size <= 2) {
+static void llama_sampler_tail_free_impl(llama_token_data_array * cur_p, float z, size_t min_keep) {
+    if (z >= 1.0f || cur_p->size <= 2) {
         return;
     }
 
-    llama_sample_softmax_impl((struct llama_sampling *) nullptr, candidates);
-    const int64_t t_start_sample_us = ggml_time_us();
+    llama_sampler_softmax_impl(cur_p);
 
     // Compute the first and second derivatives
-    std::vector<float> first_derivatives(candidates->size - 1);
-    std::vector<float> second_derivatives(candidates->size - 2);
+    std::vector<float> first_derivatives(cur_p->size - 1);
+    std::vector<float> second_derivatives(cur_p->size - 2);
 
     for (size_t i = 0; i < first_derivatives.size(); ++i) {
-        first_derivatives[i] = candidates->data[i].p - candidates->data[i + 1].p;
+        first_derivatives[i] = cur_p->data[i].p - cur_p->data[i + 1].p;
     }
     for (size_t i = 0; i < second_derivatives.size(); ++i) {
         second_derivatives[i] = first_derivatives[i] - first_derivatives[i + 1];
@@ -272,7 +258,7 @@ void llama_sample_tail_free_impl(struct llama_sampling * smpl, llama_token_data_
     }
 
     float cum_sum = 0.0f;
-    size_t last_idx = candidates->size;
+    size_t last_idx = cur_p->size;
     for (size_t i = 0; i < second_derivatives.size(); ++i) {
         cum_sum += second_derivatives[i];
 
@@ -284,14 +270,10 @@ void llama_sample_tail_free_impl(struct llama_sampling * smpl, llama_token_data_
     }
 
     // Resize the output vector to keep only the tokens above the tail location
-    candidates->size = last_idx;
-
-    if (smpl) {
-        smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
-    }
+    cur_p->size = last_idx;
 }
 
-void llama_sample_typical_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep) {
+static void llama_sampler_typical_impl(llama_token_data_array * cur_p, float p, size_t min_keep) {
     // Reference implementation:
     // https://github.com/huggingface/transformers/compare/main...cimeister:typical-sampling:typical-pr
     if (p >= 1.0f) {
@@ -299,24 +281,22 @@ void llama_sample_typical_impl(struct llama_sampling * smpl, llama_token_data_ar
     }
 
     // Compute the softmax of logits and calculate entropy
-    llama_sample_softmax_impl((struct llama_sampling *) nullptr, candidates);
-
-    const int64_t t_start_sample_us = ggml_time_us();
+    llama_sampler_softmax_impl(cur_p);
 
     float entropy = 0.0f;
-    for (size_t i = 0; i < candidates->size; ++i) {
-        entropy += -candidates->data[i].p * logf(candidates->data[i].p);
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        entropy += -cur_p->data[i].p * logf(cur_p->data[i].p);
     }
 
     // Compute the absolute difference between negative log probability and entropy for each candidate
     std::vector<float> shifted_scores;
-    for (size_t i = 0; i < candidates->size; ++i) {
-        float shifted_score = fabsf(-logf(candidates->data[i].p) - entropy);
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        float shifted_score = fabsf(-logf(cur_p->data[i].p) - entropy);
         shifted_scores.push_back(shifted_score);
     }
 
     // Sort tokens based on the shifted_scores and their corresponding indices
-    std::vector<size_t> indices(candidates->size);
+    std::vector<size_t> indices(cur_p->size);
     std::iota(indices.begin(), indices.end(), 0);
 
     std::sort(indices.begin(), indices.end(), [&](size_t a, size_t b) {
@@ -329,7 +309,7 @@ void llama_sample_typical_impl(struct llama_sampling * smpl, llama_token_data_ar
 
     for (size_t i = 0; i < indices.size(); ++i) {
         size_t idx = indices[i];
-        cum_sum += candidates->data[idx].p;
+        cum_sum += cur_p->data[idx].p;
 
         // Check if the running sum is greater than typical or if we have kept at least min_keep tokens
         if (cum_sum > p && i >= min_keep - 1) {
@@ -339,45 +319,39 @@ void llama_sample_typical_impl(struct llama_sampling * smpl, llama_token_data_ar
     }
 
     // Resize the output vector to keep only the locally typical tokens
-    std::vector<llama_token_data> new_candidates;
+    std::vector<llama_token_data> cur_p_new;
     for (size_t i = 0; i < last_idx; ++i) {
         size_t idx = indices[i];
-        new_candidates.push_back(candidates->data[idx]);
+        cur_p_new.push_back(cur_p->data[idx]);
     }
 
-    // Replace the data in candidates with the new_candidates data
-    std::copy(new_candidates.begin(), new_candidates.end(), candidates->data);
-    candidates->size = new_candidates.size();
-    candidates->sorted = false;
-
-    if (smpl) {
-        smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
-    }
+    // Replace the data in cur_p with the cur_p_new data
+    std::copy(cur_p_new.begin(), cur_p_new.end(), cur_p->data);
+    cur_p->size = cur_p_new.size();
+    cur_p->sorted = false;
 }
 
-void llama_sample_entropy_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float min_temp, float max_temp, float exponent_val) {
-    const int64_t t_start_sample_us = ggml_time_us();
-
+static void llama_sampler_entropy_impl(llama_token_data_array * cur_p, float min_temp, float max_temp, float exponent_val) {
     // no need to do anything if there is only one (or zero) candidates
-    if(candidates->size <= 1) {
+    if (cur_p->size <= 1) {
         return;
     }
 
     // Calculate maximum possible entropy
-    float max_entropy = -logf(1.0f / candidates->size);
+    float max_entropy = -logf(1.0f / cur_p->size);
 
-    llama_sample_softmax_impl((struct llama_sampling *) nullptr, candidates);
+    llama_sampler_softmax_impl(cur_p);
 
     // Calculate entropy of the softmax probabilities
     float entropy = 0.0f;
-    for (size_t i = 0; i < candidates->size; ++i) {
-        float prob = candidates->data[i].p;
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        float prob = cur_p->data[i].p;
         if (prob > 0.0f) { // Ensure no log(0)
             entropy -= prob * logf(prob);
         }
     }
 
-    // Normalize the entropy (max_entropy cannot be 0 here because we checked candidates->size != 1 above)
+    // Normalize the entropy (max_entropy cannot be 0 here because we checked cur_p->size != 1 above)
     float normalized_entropy = entropy / max_entropy;
 
     // Map the normalized entropy to the desired temperature range using the power function
@@ -393,70 +367,52 @@ void llama_sample_entropy_impl(struct llama_sampling * smpl, llama_token_data_ar
 #endif
 
     // Apply the dynamically calculated temperature scaling
-    for (size_t i = 0; i < candidates->size; ++i) {
-        candidates->data[i].logit /= dyn_temp;
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        cur_p->data[i].logit /= dyn_temp;
     }
 
     // Re-compute softmax probabilities after scaling logits with dynamic temperature
-    double max_l_double = candidates->data[0].logit;
+    const double max_l_double = cur_p->data[0].logit;
+
     double cum_sum_double = 0.0;
-    for (size_t i = 0; i < candidates->size; ++i) {
-        double p = exp(candidates->data[i].logit - max_l_double);
-        candidates->data[i].p = p; // Store the scaled probability
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        double p = exp(cur_p->data[i].logit - max_l_double);
+        cur_p->data[i].p = p; // Store the scaled probability
         cum_sum_double += p;
     }
-    for (size_t i = 0; i < candidates->size; ++i) {
-        candidates->data[i].p /= cum_sum_double; // Re-normalize the probabilities
+
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        cur_p->data[i].p /= cum_sum_double; // Re-normalize the probabilities
     }
 
 #ifdef DEBUG
     // Print the updated top 25 probabilities after temperature scaling
     LLAMA_LOG_INFO("\nUpdated Top 25 Probabilities After Dynamic Temperature Scaling (in percentages):\n");
-    for (size_t i = 0; i < 25 && i < candidates->size; ++i) {
-        LLAMA_LOG_INFO("Token %zu: %f%%\n", i + 1, candidates->data[i].p * 100.0f);
+    for (size_t i = 0; i < 25 && i < cur_p->size; ++i) {
+        LLAMA_LOG_INFO("Token %zu: %f%%\n", i + 1, cur_p->data[i].p * 100.0f);
     }
 #endif
-
-    if (smpl) {
-        smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
-    }
 }
 
-void llama_sample_temp_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float temp) {
-    const int64_t t_start_sample_us = ggml_time_us();
-
-    for (size_t i = 0; i < candidates->size; ++i) {
-        candidates->data[i].logit /= temp;
-    }
-
-    if (smpl) {
-        smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
+static void llama_sampler_temp_impl(llama_token_data_array * cur_p, float temp) {
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        cur_p->data[i].logit /= temp;
     }
 }
 
-void llama_sample_repetition_penalties_impl(
-        struct llama_sampling * smpl,
-       llama_token_data_array * candidates,
-            const llama_token * last_tokens,
-                       size_t   penalty_last_n,
-                       float   penalty_repeat,
-                       float   penalty_freq,
-                       float   penalty_present) {
-    if (penalty_last_n == 0 || (penalty_repeat == 1.0f && penalty_freq == 0.0f && penalty_present == 0.0f)) {
-        return;
-    }
-
-    const int64_t t_start_sample_us = ggml_time_us();
-
-    // Create a frequency map to count occurrences of each token in last_tokens
-    std::unordered_map<llama_token, int> token_count;
-    for (size_t i = 0; i < penalty_last_n; ++i) {
-        token_count[last_tokens[i]]++;
-    }
+static void llama_sampler_grammar_impl(llama_token_data_array * cur_p, const struct llama_grammar & grammar) {
+    llama_grammar_apply_impl(grammar, cur_p);
+}
 
-    // Apply frequency and presence penalties to the candidates
-    for (size_t i = 0; i < candidates->size; ++i) {
-        const auto token_iter = token_count.find(candidates->data[i].id);
+void llama_sampler_penalties_impl(
+       llama_token_data_array * cur_p,
+        const llama_token_cnt & token_count,
+                        float   penalty_repeat,
+                        float   penalty_freq,
+                        float   penalty_present) {
+    // Apply frequency and presence penalties to the cur_p
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        const auto token_iter = token_count.find(cur_p->data[i].id);
         if (token_iter == token_count.end()) {
             continue;
         }
@@ -465,171 +421,999 @@ void llama_sample_repetition_penalties_impl(
 
         // The academic publication that described this technique actually just only divided, but that would cause tokens with negative logits to become more likely, which is obviously wrong.
         // This is common fix for this problem, which is to multiply by the penalty instead of dividing.
-        if (candidates->data[i].logit <= 0) {
-            candidates->data[i].logit *= penalty_repeat;
+        if (cur_p->data[i].logit <= 0) {
+            cur_p->data[i].logit *= penalty_repeat;
         } else {
-            candidates->data[i].logit /= penalty_repeat;
+            cur_p->data[i].logit /= penalty_repeat;
         }
 
-        candidates->data[i].logit -= float(count) * penalty_freq + float(count > 0) * penalty_present;
+        cur_p->data[i].logit -= float(count) * penalty_freq + float(count > 0) * penalty_present;
+    }
+
+    cur_p->sorted = false;
+}
+
+// llama_sampler API
+
+const char * llama_sampler_name(const struct llama_sampler * smpl) {
+    if (!smpl->iface) {
+        return "(null)";
+    }
+
+    return smpl->iface->name(smpl);
+}
+
+void llama_sampler_accept(struct llama_sampler * smpl, llama_token token) {
+    if (smpl->iface->accept) {
+        smpl->iface->accept(smpl, token);
+    }
+}
+
+void llama_sampler_apply(struct llama_sampler * smpl, struct llama_token_data_array * cur_p) {
+    GGML_ASSERT(smpl->iface->apply);
+    smpl->iface->apply(smpl, cur_p);
+}
+
+void llama_sampler_reset(struct llama_sampler * smpl) {
+    if (smpl->iface->reset) {
+        smpl->iface->reset(smpl);
     }
+}
 
-    candidates->sorted = false;
+struct llama_sampler * llama_sampler_clone(const struct llama_sampler * smpl) {
+    if (smpl->iface->clone) {
+        return smpl->iface->clone(smpl);
+    }
 
-    if (smpl) {
-        smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
+    if (smpl->ctx == nullptr) {
+        return new llama_sampler {
+            /* .iface = */ smpl->iface,
+            /* .ctx   = */ nullptr,
+        };
     }
+
+    GGML_ABORT("the sampler does not support cloning");
 }
 
-void llama_sample_apply_guidance_impl(
-        struct llama_sampling * smpl,
-                        float * logits,
-                        float * logits_guidance,
-                        float   scale) {
-    GGML_ASSERT(smpl);
+void llama_sampler_free(struct llama_sampler * smpl) {
+    if (smpl == nullptr) {
+        return;
+    }
+
+    if (smpl->iface->free) {
+        smpl->iface->free(smpl);
+    }
 
-    const auto t_start_sample_us = ggml_time_us();
-    const auto n_vocab = smpl->n_vocab;
+    delete smpl;
+}
 
-    llama_log_softmax(logits, n_vocab);
-    llama_log_softmax(logits_guidance, n_vocab);
+llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_context * ctx, int32_t idx) {
+    const auto * logits = llama_get_logits_ith(ctx, idx);
 
-    for (int i = 0; i < n_vocab; ++i) {
-              auto & l = logits[i];
-        const auto & g = logits_guidance[i];
+    const int n_vocab = llama_n_vocab(llama_get_model(ctx));
 
-        l = scale * (l - g) + g;
+    // TODO: do not allocate each time
+    std::vector<llama_token_data> cur(n_vocab);
+    for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
+        cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f};
     }
 
-    smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
+    llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false };
+
+    llama_sampler_apply(smpl, &cur_p);
+
+    return cur_p.data[cur_p.selected].id;
 }
 
-llama_token llama_sample_token_mirostat_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float tau, float eta, int32_t m, float * mu) {
-    GGML_ASSERT(smpl);
+// sampler chain
+
+static struct llama_sampler_i llama_sampler_chain_i = {
+    /* .name   = */ [](const struct llama_sampler * /*smpl*/) { return "chain"; },
+    /* .accept = */ [](struct llama_sampler * smpl, llama_token token) {
+        auto * chain = (llama_sampler_chain *) smpl->ctx;
+
+        time_meas tm(chain->t_sample_us, chain->params.no_perf);
+
+        for (auto * smpl : chain->samplers) {
+            llama_sampler_accept(smpl, token);
+        }
+
+        chain->n_sample++;
+    },
+    /* .apply  = */ [](struct llama_sampler * smpl, llama_token_data_array * cur_p) {
+        auto * chain = (llama_sampler_chain *) smpl->ctx;
+
+        time_meas tm(chain->t_sample_us, chain->params.no_perf);
+
+        for (auto * smpl : chain->samplers) {
+            llama_sampler_apply(smpl, cur_p);
+        }
+    },
+    /* .reset  = */ [](struct llama_sampler * smpl) {
+        auto * chain = (llama_sampler_chain *) smpl->ctx;
+
+        for (auto * smpl : chain->samplers) {
+            llama_sampler_reset(smpl);
+        }
+
+        chain->t_sample_us = 0;
+        chain->n_sample    = 0;
+    },
+    /* .clone  = */ [](const struct llama_sampler * smpl) {
+        const auto * chain_src = (const llama_sampler_chain *) smpl->ctx;
 
-    const int32_t n_vocab = float(smpl->n_vocab);
+        auto * result = llama_sampler_chain_init(chain_src->params);
 
-    int64_t t_start_sample_us = ggml_time_us();
+        for (auto * smpl : chain_src->samplers) {
+            llama_sampler_chain_add(result, llama_sampler_clone(smpl));
+        }
+
+        return result;
+    },
+    /* .free   = */ [](struct llama_sampler * smpl) {
+        auto * chain = (llama_sampler_chain *) smpl->ctx;
+
+        for (auto * smpl : chain->samplers) {
+            llama_sampler_free(smpl);
+        }
 
-    llama_sample_softmax_impl((struct llama_sampling *) nullptr, candidates);
+        delete chain;
+    },
+};
+
+struct llama_sampler * llama_sampler_chain_init(struct llama_sampler_chain_params params) {
+    return new llama_sampler {
+        /* .iface = */ &llama_sampler_chain_i,
+        /* .ctx   = */ new llama_sampler_chain {
+            /* .params      = */ params,
+            /* .samplers    = */ {},
+            /* .t_sample_us = */ 0,
+            /* .n_sample    = */ 0,
+        },
+    };
+}
+
+void llama_sampler_chain_add(struct llama_sampler * chain, struct llama_sampler * smpl) {
+    auto * p = (llama_sampler_chain *) chain->ctx;
+    p->samplers.push_back(smpl);
+}
 
-    // Estimate s_hat using the most probable m tokens
-    float s_hat = 0.0;
-    float sum_ti_bi = 0.0;
-    float sum_ti_sq = 0.0;
-    for (size_t i = 0; i < size_t(m - 1) && i < candidates->size - 1; ++i) {
-        float t_i = logf(float(i + 2) / float(i + 1));
-        float b_i = logf(candidates->data[i].p / candidates->data[i + 1].p);
-        sum_ti_bi += t_i * b_i;
-        sum_ti_sq += t_i * t_i;
+struct llama_sampler * llama_sampler_chain_get(const struct llama_sampler * chain, int32_t i) {
+    const auto * p = (const llama_sampler_chain *) chain->ctx;
+
+    if (i < 0 || i >= (int32_t) p->samplers.size()) {
+        return nullptr;
     }
-    s_hat = sum_ti_bi / sum_ti_sq;
 
-    // Compute k from the estimated s_hat and target surprise value
-    float epsilon_hat = s_hat - 1;
-    float k = powf((epsilon_hat * powf(2, *mu)) / (1 - powf(n_vocab, -epsilon_hat)), 1 / s_hat);
+    return p->samplers[i];
+}
+
+int llama_sampler_chain_n(const struct llama_sampler * chain) {
+    const auto * p = (const llama_sampler_chain *) chain->ctx;
 
-    // Sample the next word X using top-k sampling
-    llama_sample_top_k_impl((struct llama_sampling *) nullptr, candidates, int(k), 1);
-    smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
-    llama_token X = llama_sample_token_impl(smpl, candidates);
-    t_start_sample_us = ggml_time_us();
+    return p->samplers.size();
+}
 
-    // Compute error as the difference between observed surprise and target surprise value
-    size_t X_idx = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) {
-        return candidate.id == X;
-    }));
-    float observed_surprise = -log2f(candidates->data[X_idx].p);
-    float e = observed_surprise - tau;
+//
+// samplers
+//
 
-    // Update mu using the learning rate and error
-    *mu = *mu - eta * e;
+// greedy
 
-    smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
-    return X;
+static struct llama_sampler_i llama_sampler_greedy_i = {
+    /* .name   = */ [](const struct llama_sampler * /*smpl*/) { return "greedy"; },
+    /* .accept = */ nullptr,
+    /* .apply  = */ [](struct llama_sampler * /*smpl*/, llama_token_data_array * cur_p) {
+        cur_p->selected = 0;
+        for (size_t i = 1; i < cur_p->size; ++i) {
+            if (cur_p->data[i].logit > cur_p->data[cur_p->selected].logit) {
+                cur_p->selected = i;
+            }
+        }
+    },
+    /* .reset  = */ nullptr,
+    /* .clone  = */ nullptr,
+    /* .free   = */ nullptr,
+};
+
+struct llama_sampler * llama_sampler_init_greedy() {
+    return new llama_sampler {
+        /* .iface = */ &llama_sampler_greedy_i,
+        /* .ctx   = */ nullptr,
+    };
 }
 
-llama_token llama_sample_token_mirostat_v2_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float tau, float eta, float * mu) {
-    int64_t t_start_sample_us;
-    t_start_sample_us = ggml_time_us();
+// dist
 
-    llama_sample_softmax_impl(smpl, candidates);
+struct llama_sampler_dist {
+    const uint32_t seed;
 
-    // Truncate the words with surprise values greater than mu
-    candidates->size = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) {
-        return -log2f(candidate.p) > *mu;
-    }));
+    std::mt19937 rng;
 
-    if (candidates->size == 0) {
-        candidates->size = 1;
-    }
+    std::vector<float> probs; // work array
+};
 
-    if (smpl) {
-        smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
-    }
+static struct llama_sampler_i llama_sampler_dist_i = {
+    /* .name   = */ [](const struct llama_sampler * /*smpl*/) { return "dist"; },
+    /* .accept = */ nullptr,
+    /* .apply  = */ [](struct llama_sampler * smpl, llama_token_data_array * cur_p) {
+        auto * ctx = (llama_sampler_dist *) smpl->ctx;
+        cur_p->selected = llama_sample_dist(cur_p, ctx->rng, ctx->probs);
+    },
+    /* .reset  = */ nullptr,
+    /* .clone  = */ [](const struct llama_sampler * smpl) {
+        const auto * ctx = (const llama_sampler_dist *) smpl->ctx;
+        auto * result = llama_sampler_init_dist(ctx->seed);
 
-    // Normalize the probabilities of the remaining words
-    llama_sample_softmax_impl(smpl, candidates);
+        // copy the state
+        {
+            auto * result_ctx = (llama_sampler_dist *) result->ctx;
 
-    // Sample the next word X from the remaining words
-    llama_token X = llama_sample_token_impl(smpl, candidates);
-    t_start_sample_us = ggml_time_us();
+            result_ctx->rng = ctx->rng;
+        }
 
-    // Compute error as the difference between observed surprise and target surprise value
-    size_t X_idx = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) {
-        return candidate.id == X;
-    }));
-    float observed_surprise = -log2f(candidates->data[X_idx].p);
-    float e = observed_surprise - tau;
+        return result;
+    },
+    /* .free   = */ [](struct llama_sampler * smpl) {
+        delete (llama_sampler_dist *) smpl->ctx;
+    },
+};
+
+struct llama_sampler * llama_sampler_init_dist(uint32_t seed) {
+    return new llama_sampler {
+        /* .iface = */ &llama_sampler_dist_i,
+        /* .ctx   = */ new llama_sampler_dist {
+            /* .seed = */ seed,
+            /* .rng  = */ std::mt19937(seed),
+            /* .probs = */ {},
+        },
+    };
+}
 
-    // Update mu using the learning rate and error
-    *mu = *mu - eta * e;
+// softmax
+
+static struct llama_sampler_i llama_sampler_softmax_i = {
+    /* .name   = */ [](const struct llama_sampler * /*smpl*/) { return "softmax"; },
+    /* .accept = */ nullptr,
+    /* .apply  = */ [](struct llama_sampler * /*smpl*/, llama_token_data_array * cur_p) {
+        llama_sampler_softmax_impl(cur_p);
+    },
+    /* .reset  = */ nullptr,
+    /* .clone  = */ nullptr,
+    /* .free   = */ nullptr,
+};
+
+struct llama_sampler * llama_sampler_init_softmax() {
+    return new llama_sampler {
+        /* .iface = */ &llama_sampler_softmax_i,
+        /* .ctx   = */ nullptr,
+    };
+}
 
-    if (smpl) {
-        smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
-    }
-    return X;
+// top-k
+
+struct llama_sampler_top_k {
+    const int32_t k;
+};
+
+static struct llama_sampler_i llama_sampler_top_k_i = {
+    /* .name   = */ [](const struct llama_sampler * /*smpl*/) { return "top-k"; },
+    /* .accept = */ nullptr,
+    /* .apply  = */ [](struct llama_sampler * smpl, llama_token_data_array * cur_p) {
+        const auto * ctx = (llama_sampler_top_k *) smpl->ctx;
+        llama_sampler_top_k_impl(cur_p, ctx->k);
+    },
+    /* .reset  = */ nullptr,
+    /* .clone  = */ [](const struct llama_sampler * smpl) {
+        const auto * ctx = (const llama_sampler_top_k *) smpl->ctx;
+        return llama_sampler_init_top_k(ctx->k);
+    },
+    /* .free   = */ [](struct llama_sampler * smpl) {
+        delete (llama_sampler_top_k *) smpl->ctx;
+    },
+};
+
+struct llama_sampler * llama_sampler_init_top_k(int32_t k) {
+    return new llama_sampler {
+        /* .iface = */ &llama_sampler_top_k_i,
+        /* .ctx   = */ new llama_sampler_top_k {
+            /* .k = */ k,
+        },
+    };
 }
 
-llama_token llama_sample_token_greedy_impl(struct llama_sampling * smpl, llama_token_data_array * candidates) {
-    const int64_t t_start_sample_us = ggml_time_us();
+// top-p
+
+struct llama_sampler_top_p {
+    const float  p;
+    const size_t min_keep;
+};
+
+static struct llama_sampler_i llama_sampler_top_p_i = {
+    /* .name   = */ [](const struct llama_sampler * /*smpl*/) { return "top-p"; },
+    /* .accept = */ nullptr,
+    /* .apply  = */ [](struct llama_sampler * smpl, llama_token_data_array * cur_p) {
+        const auto * ctx = (llama_sampler_top_p *) smpl->ctx;
+        llama_sampler_top_p_impl(cur_p, ctx->p, ctx->min_keep);
+    },
+    /* .reset  = */ nullptr,
+    /* .clone  = */ [](const struct llama_sampler * smpl) {
+        const auto * ctx = (const llama_sampler_top_p *) smpl->ctx;
+        return llama_sampler_init_top_p(ctx->p, ctx->min_keep);
+    },
+    /* .free   = */ [](struct llama_sampler * smpl) {
+        delete (llama_sampler_top_p *) smpl->ctx;
+    },
+};
+
+struct llama_sampler * llama_sampler_init_top_p(float p, size_t min_keep) {
+    return new llama_sampler {
+        /* .iface = */ &llama_sampler_top_p_i,
+        /* .ctx   = */ new llama_sampler_top_p {
+            /* .p        = */ p,
+            /* .min_keep = */ min_keep,
+        },
+    };
+}
 
-    // Find max element
-    auto * max_iter = std::max_element(candidates->data, candidates->data + candidates->size, [](const llama_token_data & a, const llama_token_data & b) {
-        return a.logit < b.logit;
-    });
+// min-p
+
+struct llama_sampler_min_p {
+    const float  p;
+    const size_t min_keep;
+};
+
+static struct llama_sampler_i llama_sampler_min_p_i = {
+    /* .name   = */ [](const struct llama_sampler * /*smpl*/) { return "min-p"; },
+    /* .accept = */ nullptr,
+    /* .apply  = */ [](struct llama_sampler * smpl, llama_token_data_array * cur_p) {
+        const auto * ctx = (llama_sampler_min_p *) smpl->ctx;
+        llama_sampler_min_p_impl(cur_p, ctx->p, ctx->min_keep);
+    },
+    /* .reset  = */ nullptr,
+    /* .clone  = */ [](const struct llama_sampler * smpl) {
+        const auto * ctx = (const llama_sampler_min_p *) smpl->ctx;
+        return llama_sampler_init_min_p(ctx->p, ctx->min_keep);
+    },
+    /* .free   = */ [](struct llama_sampler * smpl) {
+        delete (llama_sampler_min_p *) smpl->ctx;
+    },
+};
+
+struct llama_sampler * llama_sampler_init_min_p(float p, size_t min_keep) {
+    return new llama_sampler {
+        /* .iface = */ &llama_sampler_min_p_i,
+        /* .ctx   = */ new llama_sampler_min_p {
+            /* .p        = */ p,
+            /* .min_keep = */ min_keep,
+        },
+    };
+}
 
-    llama_token result = max_iter->id;
-    if (smpl) {
-        smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
-        smpl->n_sample++;
-    }
-    return result;
+// tail-free
+
+struct llama_sampler_tail_free {
+    const float  z;
+    const size_t min_keep;
+};
+
+static struct llama_sampler_i llama_sampler_tail_free_i = {
+    /* .name   = */ [](const struct llama_sampler * /*smpl*/) { return "tail-free"; },
+    /* .accept = */ nullptr,
+    /* .apply  = */ [](struct llama_sampler * smpl, llama_token_data_array * cur_p) {
+        const auto * ctx = (llama_sampler_tail_free *) smpl->ctx;
+        llama_sampler_tail_free_impl(cur_p, ctx->z, ctx->min_keep);
+    },
+    /* .reset  = */ nullptr,
+    /* .clone  = */ [](const struct llama_sampler * smpl) {
+        const auto * ctx = (const llama_sampler_tail_free *) smpl->ctx;
+        return llama_sampler_init_tail_free(ctx->z, ctx->min_keep);
+    },
+    /* .free   = */ [](struct llama_sampler * smpl) {
+        delete (llama_sampler_tail_free *) smpl->ctx;
+    },
+};
+
+struct llama_sampler * llama_sampler_init_tail_free(float z, size_t min_keep) {
+    return new llama_sampler {
+        /* .iface = */ &llama_sampler_tail_free_i,
+        /* .ctx   = */ new llama_sampler_tail_free {
+            /* .z        = */ z,
+            /*. min_keep = */ min_keep,
+        },
+    };
 }
 
-llama_token llama_sample_token_with_rng_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, std::mt19937 & rng) {
-    GGML_ASSERT(smpl);
+// typical
+
+struct llama_sampler_typical {
+    const float  p;
+    const size_t min_keep;
+};
+
+static struct llama_sampler_i llama_sampler_typical_i = {
+    /* .name   = */ [](const struct llama_sampler * /*smpl*/) { return "typical"; },
+    /* .accept = */ nullptr,
+    /* .apply  = */ [](struct llama_sampler * smpl, llama_token_data_array * cur_p) {
+        const auto * ctx = (llama_sampler_typical *) smpl->ctx;
+        llama_sampler_typical_impl(cur_p, ctx->p, ctx->min_keep);
+    },
+    /* .reset  = */ nullptr,
+    /* .clone  = */ [](const struct llama_sampler * smpl) {
+        const auto * ctx = (const llama_sampler_typical *) smpl->ctx;
+        return llama_sampler_init_typical(ctx->p, ctx->min_keep);
+    },
+    /* .free   = */ [](struct llama_sampler * smpl) {
+        delete (llama_sampler_typical *) smpl->ctx;
+    },
+};
+
+struct llama_sampler * llama_sampler_init_typical(float p, size_t min_keep) {
+    return new llama_sampler {
+        /* .iface = */ &llama_sampler_typical_i,
+        /* .ctx   = */ new llama_sampler_typical {
+            /* .p        = */ p,
+            /* .min_keep = */ min_keep,
+        },
+    };
+}
+
+// temp
+
+struct llama_sampler_temp {
+    const float temp;
+};
+
+static struct llama_sampler_i llama_sampler_temp_i = {
+    /* .name   = */ [](const struct llama_sampler * /*smpl*/) { return "temp"; },
+    /* .accept = */ nullptr,
+    /* .apply  = */ [](struct llama_sampler * smpl, llama_token_data_array * cur_p) {
+        const auto * ctx = (llama_sampler_temp *) smpl->ctx;
+        llama_sampler_temp_impl(cur_p, ctx->temp);
+    },
+    /* .reset  = */ nullptr,
+    /* .clone  = */ [](const struct llama_sampler * smpl) {
+        const auto * ctx = (const llama_sampler_temp *) smpl->ctx;
+        return llama_sampler_init_temp(ctx->temp);
+    },
+    /* .free   = */ [](struct llama_sampler * smpl) {
+        delete (llama_sampler_temp *) smpl->ctx;
+    },
+};
+
+struct llama_sampler * llama_sampler_init_temp(float temp) {
+    return new llama_sampler {
+        /* .iface = */ &llama_sampler_temp_i,
+        /* .ctx   = */ new llama_sampler_temp {
+            /*.temp = */ temp,
+        },
+    };
+}
 
-    const int64_t t_start_sample_us = ggml_time_us();
-    llama_sample_softmax_impl((struct llama_sampling *) nullptr, candidates);
+// temp-ext
+
+struct llama_sampler_temp_ext {
+    const float temp;
+    const float delta;
+    const float exponent;
+};
+
+static struct llama_sampler_i llama_sampler_temp_ext_i = {
+    /* .name   = */ [](const struct llama_sampler * /*smpl*/) { return "temp-ext"; },
+    /* .accept = */ nullptr,
+    /* .apply  = */ [](struct llama_sampler * smpl, llama_token_data_array * cur_p) {
+        const auto * ctx = (llama_sampler_temp_ext *) smpl->ctx;
+        if (ctx->delta > 0) {
+            const float temp_min = std::max(0.0f, ctx->temp - ctx->delta);
+            const float temp_max = ctx->temp + ctx->delta;
+
+            llama_sampler_entropy_impl(cur_p, temp_min, temp_max, ctx->exponent);
+        } else {
+            llama_sampler_temp_impl(cur_p, ctx->temp);
+        }
+    },
+    /* .reset  = */ nullptr,
+    /* .clone  = */ [](const struct llama_sampler * smpl) {
+        const auto * ctx = (const llama_sampler_temp_ext *) smpl->ctx;
+        return llama_sampler_init_temp_ext(ctx->temp, ctx->delta, ctx->exponent);
+    },
+    /* .free   = */ [](struct llama_sampler * smpl) {
+        delete (llama_sampler_temp_ext *) smpl->ctx;
+    },
+};
+
+struct llama_sampler * llama_sampler_init_temp_ext(float temp, float delta, float exponent) {
+    return new llama_sampler {
+        /* .iface = */ &llama_sampler_temp_ext_i,
+        /* .ctx   = */ new llama_sampler_temp_ext {
+            /* .temp     = */ temp,
+            /* .delta    = */ delta,
+            /* .exponent = */ exponent,
+        },
+    };
+}
+
+// mirostat
+
+struct llama_sampler_mirostat {
+    const int32_t n_vocab;
+
+    const uint32_t seed;
+
+    const float tau;
+    const float eta;
+
+    const int32_t m;
+
+    float mu;
+
+    std::mt19937 rng;
+
+    std::vector<float> probs;
+};
+
+static struct llama_sampler_i llama_sampler_mirostat_i = {
+    /* .name   = */ [](const struct llama_sampler * /*smpl*/) { return "mirostat"; },
+    /* .accept = */ nullptr,
+    /* .apply  = */ [](struct llama_sampler * smpl, llama_token_data_array * cur_p) {
+        auto * ctx = (llama_sampler_mirostat *) smpl->ctx;
+
+        llama_sampler_softmax_impl(cur_p);
+
+        // Estimate s_hat using the most probable m tokens
+        float s_hat = 0.0;
+        float sum_ti_bi = 0.0;
+        float sum_ti_sq = 0.0;
+        for (size_t i = 0; i < size_t(ctx->m - 1) && i < cur_p->size - 1; ++i) {
+            float t_i = logf(float(i + 2) / float(i + 1));
+            float b_i = logf(cur_p->data[i].p / cur_p->data[i + 1].p);
+            sum_ti_bi += t_i * b_i;
+            sum_ti_sq += t_i * t_i;
+        }
+        s_hat = sum_ti_bi / sum_ti_sq;
+
+        // Compute k from the estimated s_hat and target surprise value
+        float epsilon_hat = s_hat - 1;
+        float k = powf((epsilon_hat * powf(2, ctx->mu)) / (1 - powf(ctx->n_vocab, -epsilon_hat)), 1 / s_hat);
+
+        llama_sampler_top_k_impl(cur_p, std::max(int(k), 1));
+        llama_sampler_softmax_impl(cur_p);
+
+        const int idx = llama_sample_dist(cur_p, ctx->rng, ctx->probs);
+
+        cur_p->selected = idx;
+
+        float observed_surprise = -log2f(cur_p->data[idx].p);
+        float e = observed_surprise - ctx->tau;
+
+        // Update mu using the learning rate and error
+        ctx->mu = ctx->mu - ctx->eta * e;
+    },
+    /* .reset  = */ [](struct llama_sampler * smpl) {
+        auto * ctx = (llama_sampler_mirostat *) smpl->ctx;
+        ctx->mu = 2.0f*ctx->tau;
+        ctx->rng = std::mt19937(ctx->seed);
+    },
+    /* .clone  = */ [](const struct llama_sampler * smpl) {
+        const auto * ctx = (const llama_sampler_mirostat *) smpl->ctx;
+        auto * result = llama_sampler_init_mirostat(ctx->n_vocab, ctx->seed, ctx->tau, ctx->eta, ctx->m);
+
+        // copy the state
+        {
+            auto * result_ctx = (llama_sampler_mirostat *) smpl->ctx;
+
+            result_ctx->mu  = ctx->mu;
+            result_ctx->rng = ctx->rng;
+        }
+
+        return result;
+    },
+    /* .free   = */ [](struct llama_sampler * smpl) {
+        delete (llama_sampler_mirostat *) smpl->ctx;
+    },
+};
+
+struct llama_sampler * llama_sampler_init_mirostat(int32_t n_vocab, uint32_t seed, float tau, float eta, int32_t m) {
+    return new llama_sampler {
+        /* .iface = */ &llama_sampler_mirostat_i,
+        /* .ctx   = */ new llama_sampler_mirostat {
+            /* .n_vocab = */ n_vocab,
+            /* .seed    = */ seed,
+            /* .tau     = */ tau,
+            /* .eta     = */ eta,
+            /* .m       = */ m,
+            /* .mu      = */ 2.0f*tau,
+            /* .rng     = */ std::mt19937(seed),
+            /* .probs   = */ {},
+        },
+    };
+}
+
+// mirostat v2
+
+struct llama_sampler_mirostat_v2 {
+    const uint32_t seed;
+
+    const float tau;
+    const float eta;
+
+    float mu;
+
+    std::mt19937 rng;
 
     std::vector<float> probs;
-    probs.reserve(candidates->size);
-    for (size_t i = 0; i < candidates->size; ++i) {
-        probs.push_back(candidates->data[i].p);
+};
+
+static struct llama_sampler_i llama_sampler_mirostat_v2_i = {
+    /* .name   = */ [](const struct llama_sampler * /*smpl*/) { return "mirostat-v2"; },
+    /* .accept = */ nullptr,
+    /* .apply  = */ [](struct llama_sampler * smpl, llama_token_data_array * cur_p) {
+        auto * ctx = (llama_sampler_mirostat_v2 *) smpl->ctx;
+
+        llama_sampler_softmax_impl(cur_p);
+
+        // Truncate the words with surprise values greater than mu
+        cur_p->size = std::distance(cur_p->data, std::find_if(cur_p->data, cur_p->data + cur_p->size, [&](const llama_token_data & candidate) {
+            return -log2f(candidate.p) > ctx->mu;
+        }));
+
+        if (cur_p->size == 0) {
+            cur_p->size = 1;
+        }
+
+        // Normalize the probabilities of the remaining words
+        llama_sampler_softmax_impl(cur_p);
+
+        const int idx = llama_sample_dist(cur_p, ctx->rng, ctx->probs);
+
+        cur_p->selected = idx;
+
+        float observed_surprise = -log2f(cur_p->data[idx].p);
+        float e = observed_surprise - ctx->tau;
+
+        // Update mu using the learning rate and error
+        ctx->mu = ctx->mu - ctx->eta * e;
+    },
+    /* .reset  = */ [](struct llama_sampler * smpl) {
+        auto * ctx = (llama_sampler_mirostat_v2 *) smpl->ctx;
+        ctx->mu = 2.0f*ctx->tau;
+        ctx->rng = std::mt19937(ctx->seed);
+    },
+    /* .clone  = */ [](const struct llama_sampler * smpl) {
+        const auto * ctx = (const llama_sampler_mirostat_v2 *) smpl->ctx;
+
+        auto * result = llama_sampler_init_mirostat_v2(ctx->seed, ctx->tau, ctx->eta);
+
+        // copy the state
+        {
+            auto * result_ctx = (llama_sampler_mirostat_v2 *) result->ctx;
+
+            result_ctx->mu  = ctx->mu;
+            result_ctx->rng = ctx->rng;
+        }
+
+        return result;
+    },
+    /* .free   = */ [](struct llama_sampler * smpl) {
+        delete (llama_sampler_mirostat_v2 *) smpl->ctx;
+    },
+};
+
+struct llama_sampler * llama_sampler_init_mirostat_v2(uint32_t seed, float tau, float eta) {
+    return new llama_sampler {
+        /* .iface = */ &llama_sampler_mirostat_v2_i,
+        /* .ctx   = */ new llama_sampler_mirostat_v2 {
+            /* .seed  = */ seed,
+            /* .tau   = */ tau,
+            /* .eta   = */ eta,
+            /* .mu    = */ 2.0f*tau,
+            /* .rng   = */ std::mt19937(seed),
+            /* .probs = */ {},
+        },
+    };
+}
+
+// grammar
+
+struct llama_sampler_grammar {
+    const struct llama_vocab * vocab;
+
+    std::string grammar_str;
+    std::string grammar_root;
+
+    struct llama_grammar * grammar;
+};
+
+static struct llama_sampler_i llama_sampler_grammar_i = {
+    /* .name   = */ [](const struct llama_sampler * /*smpl*/) { return "grammar"; },
+    /* .accept = */ [](struct llama_sampler * smpl, llama_token token) {
+        const auto * ctx = (llama_sampler_grammar *) smpl->ctx;
+        if (ctx->grammar) {
+            llama_grammar_accept_impl(*ctx->grammar, token);
+        }
+    },
+    /* .apply  = */ [](struct llama_sampler * smpl, llama_token_data_array * cur_p) {
+        const auto * ctx = (llama_sampler_grammar *) smpl->ctx;
+        if (ctx->grammar) {
+            llama_sampler_grammar_impl(cur_p, *ctx->grammar);
+        }
+    },
+    /* .reset  = */ [](struct llama_sampler * smpl) {
+        auto * ctx = (llama_sampler_grammar *) smpl->ctx;
+        if (!ctx->grammar) {
+            return;
+        }
+
+        auto * grammar_new = llama_grammar_init_impl(ctx->grammar->vocab, ctx->grammar_str.c_str(), ctx->grammar_root.c_str());
+
+        llama_grammar_free_impl(ctx->grammar);
+        ctx->grammar = grammar_new;
+    },
+    /* .clone  = */ [](const struct llama_sampler * smpl) {
+        const auto * ctx = (const llama_sampler_grammar *) smpl->ctx;
+
+        auto * result = llama_sampler_init_grammar_impl(*ctx->vocab, nullptr, nullptr);
+
+        // copy the state
+        {
+            auto * result_ctx = (llama_sampler_grammar *) result->ctx;
+
+            if (ctx->grammar) {
+                result_ctx->grammar_str  = ctx->grammar_str;
+                result_ctx->grammar_root = ctx->grammar_root;
+
+                result_ctx->grammar = llama_grammar_clone_impl(*ctx->grammar);
+            }
+        }
+
+        return result;
+    },
+    /* .free   = */ [](struct llama_sampler * smpl) {
+        const auto * ctx = (llama_sampler_grammar *) smpl->ctx;
+
+        if (ctx->grammar) {
+            llama_grammar_free_impl(ctx->grammar);
+        }
+
+        delete ctx;
+    },
+};
+
+struct llama_sampler * llama_sampler_init_grammar_impl(const struct llama_vocab & vocab, const char * grammar_str, const char * grammar_root) {
+    auto * ctx = new llama_sampler_grammar;
+
+    if (grammar_str != nullptr && grammar_str[0] != '\0') {
+        *ctx = {
+            /* .vocab        = */ &vocab,
+            /* .grammar_str  = */ grammar_str,
+            /* .grammar_root = */ grammar_root,
+            /* .grammar      = */ llama_grammar_init_impl(&vocab, grammar_str, grammar_root),
+        };
+    } else {
+        *ctx = {
+            /* .vocab        = */ &vocab,
+            /* .grammar_str  = */ {},
+            /* .grammar_root = */ {},
+            /* .grammar      = */ nullptr,
+        };
     }
 
-    std::discrete_distribution<> dist(probs.begin(), probs.end());
-    int idx = dist(rng);
+    return new llama_sampler {
+        /* .iface = */ &llama_sampler_grammar_i,
+        /* .ctx   = */ ctx,
+    };
+}
 
-    llama_token result = candidates->data[idx].id;
+// penalties
+
+struct llama_sampler_penalties {
+    const int32_t     n_vocab;
+    const llama_token special_eos_id;
+    const llama_token linefeed_id;
+
+    const int32_t penalty_last_n;
+    const float   penalty_repeat;
+    const float   penalty_freq;
+    const float   penalty_present;
+
+    const bool    penalize_nl;
+    const bool    ignore_eos;
+
+    ring_buffer<llama_token> prev;
+};
+
+static struct llama_sampler_i llama_sampler_penalties_i = {
+    /* .name   = */ [](const struct llama_sampler * /*smpl*/) { return "penalties"; },
+    /* .accept = */ [](struct llama_sampler * smpl, llama_token token) {
+        auto * ctx = (llama_sampler_penalties *) smpl->ctx;
+        ctx->prev.push_back(token);
+    },
+    /* .apply  = */ [](struct llama_sampler * smpl, llama_token_data_array * cur_p) {
+        auto * ctx = (llama_sampler_penalties *) smpl->ctx;
+
+        if (ctx->ignore_eos) {
+            assert(ctx->special_eos_id >= 0);
+
+            // optimistically check if the candidates are not yet sorted/shuffled/truncated
+            if (cur_p->size > (size_t) ctx->special_eos_id && cur_p->data[ctx->special_eos_id].id == ctx->special_eos_id) {
+                cur_p->data[ctx->special_eos_id].logit = -INFINITY;
+            } else {
+                // else, search for the special EOS token
+                for (size_t i = 0; i < cur_p->size; ++i) {
+                    if (cur_p->data[i].id == ctx->special_eos_id) {
+                        cur_p->data[i].logit = -INFINITY;
+                        break;
+                    }
+                }
+            }
+        }
 
-    smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
-    smpl->n_sample++;
+        if ((ctx->penalty_last_n == 0) ||
+            (ctx->penalty_repeat == 1.0f && ctx->penalty_freq == 0.0f && ctx->penalty_present == 0.0f)) {
+            return;
+        }
 
-    return result;
+        bool nl_found = false;
+        size_t nl_idx = 0;
+        float nl_logit = -INFINITY;
+        if (!ctx->penalize_nl) {
+            assert(ctx->linefeed_id >= 0);
+
+            // optimistically check if the candidates are not yet sorted/shuffled/truncated
+            if (cur_p->size > (size_t) ctx->linefeed_id && cur_p->data[ctx->linefeed_id].id == ctx->linefeed_id) {
+                nl_found = true;
+                nl_idx = ctx->linefeed_id;
+                nl_logit = cur_p->data[ctx->linefeed_id].logit;
+            } else {
+                // else, search for the linefeed token
+                for (size_t i = 0; i < cur_p->size; ++i) {
+                    if (cur_p->data[i].id == ctx->linefeed_id) {
+                        nl_found = true;
+                        nl_idx = i;
+                        nl_logit = cur_p->data[i].logit;
+                        break;
+                    }
+                }
+            }
+        }
+
+        // Create a frequency map to count occurrences of each token in last_tokens
+        // TODO: optimize this by maintaining the token count in the sampler context
+        llama_token_cnt token_count;
+        for (int i = 0; i < std::min<int>(ctx->penalty_last_n, ctx->prev.size()); ++i) {
+            token_count[ctx->prev.rat(i)]++;
+        }
+
+        llama_sampler_penalties_impl(cur_p, token_count, ctx->penalty_repeat, ctx->penalty_freq, ctx->penalty_present);
+
+        if (!ctx->penalize_nl && nl_found) {
+            // restore the logit of the newline token if it was penalized
+            cur_p->data[nl_idx].logit = nl_logit;
+        }
+    },
+    /* .reset  = */ [](struct llama_sampler * smpl) {
+        auto * ctx = (llama_sampler_penalties *) smpl->ctx;
+        ctx->prev.clear();
+    },
+    /* .clone  = */ [](const struct llama_sampler * smpl) {
+        const auto * ctx = (const llama_sampler_penalties *) smpl->ctx;
+        auto * result = llama_sampler_init_penalties(
+                ctx->n_vocab,
+                ctx->special_eos_id,
+                ctx->linefeed_id,
+                ctx->penalty_last_n,
+                ctx->penalty_repeat,
+                ctx->penalty_freq,
+                ctx->penalty_present,
+                ctx->penalize_nl,
+                ctx->ignore_eos);
+
+        // copy the state
+        {
+            auto * result_ctx = (llama_sampler_penalties *) result->ctx;
+
+            result_ctx->prev = ctx->prev;
+        }
+
+        return result;
+    },
+    /* .free   = */ [](struct llama_sampler * smpl) {
+        delete (llama_sampler_penalties *) smpl->ctx;
+    },
+};
+
+struct llama_sampler * llama_sampler_init_penalties(
+        int32_t n_vocab,
+        llama_token special_eos_id,
+        llama_token linefeed_id,
+        int32_t penalty_last_n,
+        float penalty_repeat,
+        float penalty_freq,
+        float penalty_present,
+        bool penalize_nl,
+        bool ignore_eos) {
+    if (linefeed_id == LLAMA_TOKEN_NULL) {
+        penalize_nl = false;
+    }
+
+    if (special_eos_id == LLAMA_TOKEN_NULL) {
+        ignore_eos = true;
+    }
+
+    return new llama_sampler {
+        /* .iface = */ &llama_sampler_penalties_i,
+        /* .ctx   = */ new llama_sampler_penalties {
+            /* .n_vocab         = */ n_vocab,
+            /* .special_eos_id  = */ special_eos_id,
+            /* .linefeed_id     = */ linefeed_id,
+            /* .penalty_last_n  = */ penalty_last_n,
+            /* .penalty_repeat  = */ penalty_repeat,
+            /* .penalty_freq    = */ penalty_freq,
+            /* .penalty_present = */ penalty_present,
+            /* .penalize_nl     = */ penalize_nl,
+            /* .ignore_eos      = */ ignore_eos,
+            /* .prev            = */ ring_buffer<llama_token>(penalty_last_n),
+        },
+    };
 }
 
-llama_token llama_sample_token_impl(struct llama_sampling * smpl, llama_token_data_array * candidates) {
-    return llama_sample_token_with_rng_impl(smpl, candidates, smpl->rng);
+// logit-bias
+
+struct llama_sampler_logit_bias {
+    const int32_t n_vocab;
+
+    const std::vector<llama_logit_bias> logit_bias;
+
+    std::vector<llama_logit_bias> to_search;
+};
+
+static struct llama_sampler_i llama_sampler_logit_bias_i = {
+    /* .name   = */ [](const struct llama_sampler * /*smpl*/) { return "logit-bias"; },
+    /* .accept = */ nullptr,
+    /* .apply  = */ [](struct llama_sampler * smpl, llama_token_data_array * cur_p) {
+        auto * ctx = (llama_sampler_logit_bias *) smpl->ctx;
+
+        ctx->to_search.clear();
+
+        // update the candidates that have not been shuffled in the vocabulary (i.e. idx == id)
+        for (const auto & lb : ctx->logit_bias) {
+            if (lb.token >= 0 && cur_p->size > (size_t) lb.token && cur_p->data[lb.token].id == lb.token) {
+                cur_p->data[lb.token].logit += lb.bias;
+            } else {
+                ctx->to_search.push_back(lb);
+            }
+        }
+
+        // search for the remaining candidates that were not found in the previous step
+        for (size_t i = 0; i < cur_p->size; ++i) {
+            for (const auto & lb : ctx->to_search) {
+                if (cur_p->data[i].id == lb.token) {
+                    cur_p->data[i].logit += lb.bias;
+                    break;
+                }
+            }
+        }
+    },
+    /* .reset  = */ nullptr,
+    /* .clone  = */ [](const struct llama_sampler * smpl) {
+        const auto * ctx = (const llama_sampler_logit_bias *) smpl->ctx;
+        return llama_sampler_init_logit_bias(ctx->n_vocab, ctx->logit_bias.size(), ctx->logit_bias.data());
+    },
+    /* .free   = */ [](struct llama_sampler * smpl) {
+        delete (llama_sampler_logit_bias *) smpl->ctx;
+    },
+};
+
+struct llama_sampler * llama_sampler_init_logit_bias(
+                         int32_t   n_vocab,
+                         int32_t   n_logit_bias,
+          const llama_logit_bias * logit_bias) {
+    return new llama_sampler {
+        /* .iface = */ &llama_sampler_logit_bias_i,
+        /* .ctx   = */ new llama_sampler_logit_bias {
+            /* .n_vocab    = */ n_vocab,
+            /* .logit_bias = */ std::vector<llama_logit_bias>(logit_bias, logit_bias + n_logit_bias),
+            /* .to_search  = */ {},
+        },
+    };
 }
index f7f8e3ef706bc8d3dde7db5953811dd72fcad105..137c0025ce0d89ea0aa765ae4f3b57604907062b 100644 (file)
@@ -1,56 +1,39 @@
 #pragma once
 
-#include "llama-impl.h"
+// TODO: rename llama-sampling.h/.cpp to llama-sampler.h/.cpp ?
 
-struct llama_sampling {
-    llama_sampling(int32_t n_vocab) : n_vocab(n_vocab) {}
+#include "llama-grammar.h"
 
-    std::mt19937 rng;
+#include <unordered_map>
 
-    int32_t n_vocab = 0;
+struct llama_vocab;
+struct llama_grammar;
 
-    mutable int64_t t_sample_us = 0;
-    mutable int32_t n_sample = 0;
+// sampler chain
 
-    void reset_timings() const {
-        t_sample_us = 0;
-        n_sample = 0;
-    }
+struct llama_sampler_chain {
+    llama_sampler_chain_params params;
+
+    std::vector<struct llama_sampler *> samplers;
+
+    // timing
+
+    mutable int64_t t_sample_us;
+
+    mutable int32_t n_sample;
 };
 
-//
-// internal API
-//
-
-void llama_set_rng_seed_impl(struct llama_sampling * smpl, uint32_t seed);
-
-void llama_sample_softmax_impl  (struct llama_sampling * smpl, llama_token_data_array * candidates);
-void llama_sample_top_k_impl    (struct llama_sampling * smpl, llama_token_data_array * candidates, int32_t k, size_t min_keep);
-void llama_sample_top_p_impl    (struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep);
-void llama_sample_min_p_impl    (struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep);
-void llama_sample_tail_free_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float z, size_t min_keep);
-void llama_sample_typical_impl  (struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep);
-void llama_sample_entropy_impl  (struct llama_sampling * smpl, llama_token_data_array * candidates, float min_temp, float max_temp, float exponent_val);
-void llama_sample_temp_impl     (struct llama_sampling * smpl, llama_token_data_array * candidates, float temp);
-
-void llama_sample_repetition_penalties_impl(
-        struct llama_sampling * smpl,
-       llama_token_data_array * candidates,
-            const llama_token * last_tokens,
-                       size_t   penalty_last_n,
+using llama_token_cnt = std::unordered_map<llama_token, int>;
+
+// TODO: tmp exposed until test-sampling is fixed
+void llama_sampler_penalties_impl(
+       llama_token_data_array * cur_p,
+        const llama_token_cnt & token_count,
                         float   penalty_repeat,
                         float   penalty_freq,
                         float   penalty_present);
 
-void llama_sample_apply_guidance_impl(
-        struct llama_sampling * smpl,
-                        float * logits,
-                        float * logits_guidance,
-                        float   scale);
-
-llama_token llama_sample_token_mirostat_impl   (struct llama_sampling * smpl, llama_token_data_array * candidates, float tau, float eta, int32_t m, float * mu);
-llama_token llama_sample_token_mirostat_v2_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float tau, float eta, float * mu);
-llama_token llama_sample_token_greedy_impl     (struct llama_sampling * smpl, llama_token_data_array * candidates);
-llama_token llama_sample_token_with_rng_impl   (struct llama_sampling * smpl, llama_token_data_array * candidates, std::mt19937 & rng);
-llama_token llama_sample_token_impl            (struct llama_sampling * smpl, llama_token_data_array * candidates);
-
+struct llama_sampler * llama_sampler_init_grammar_impl(
+        const struct llama_vocab & vocab,
+                      const char * grammar_str,
+                      const char * grammar_root);
index 6e8f30be43ba1cb2f93cd07d8189546998de49d1..dc4b5f12f7860030c2fc2a4f25e974597c0075ba 100644 (file)
@@ -18,6 +18,8 @@ struct llama_vocab {
         tattr attr;
     };
 
+    uint32_t n_vocab = 0; // TODO: not great because has to keep in sync with hparams.n_vocab
+
     enum llama_vocab_type     type     = LLAMA_VOCAB_TYPE_SPM;
     enum llama_vocab_pre_type type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
 
@@ -62,8 +64,6 @@ struct llama_vocab {
     int find_bpe_rank(const std::string & token_left, const std::string & token_right) const;
 };
 
-const struct llama_vocab * llama_get_vocab(const struct llama_context * ctx);
-
 //
 // internal API
 //
@@ -76,6 +76,7 @@ std::vector<llama_vocab::id> llama_tokenize_internal(
         bool add_special,
         bool parse_special = false);
 
+// TODO: move the API below as member functions of llama_vocab
 llama_token llama_byte_to_token_impl(const llama_vocab & vocab, uint8_t ch);
 
 const char * llama_token_get_text_impl(const struct llama_vocab & vocab, llama_token token);
index 1a78112a3a84dd9169332e1b2439335988ef49af..6bbaf9fc9bae7f8bd93a89327f734e662dd78059 100644 (file)
@@ -1,6 +1,5 @@
 #include "llama-impl.h"
 #include "llama-vocab.h"
-#include "llama-grammar.h"
 #include "llama-sampling.h"
 
 #include "unicode.h"
@@ -3179,7 +3178,6 @@ struct llama_sbatch {
 struct llama_context {
     llama_context(const llama_model & model)
         : model(model)
-        , sampling(llama_n_vocab(&model))
         , t_start_us(model.t_start_us)
         , t_load_us(model.t_load_us) {}
 
@@ -3196,7 +3194,6 @@ struct llama_context {
     const struct llama_model & model;
 
     struct llama_cparams        cparams;
-    struct llama_sampling       sampling;
     struct llama_sbatch         sbatch;
     struct llama_kv_cache       kv_self;
     struct llama_control_vector cvec;
@@ -3217,16 +3214,16 @@ struct llama_context {
 
     bool has_evaluated_once = false;
 
-    int64_t t_start_us;
-    int64_t t_load_us;
-    int64_t t_p_eval_us = 0;
-    int64_t t_eval_us   = 0;
+    mutable int64_t t_start_us;
+    mutable int64_t t_load_us;
+    mutable int64_t t_p_eval_us = 0;
+    mutable int64_t t_eval_us   = 0;
 
-    int64_t t_compute_start_us = 0;
-    int64_t n_queued_tokens = 0;
+    mutable int64_t t_compute_start_us = 0;
+    mutable int64_t n_queued_tokens = 0;
 
-    int32_t n_p_eval = 0; // number of tokens in eval calls for the prompt (with batch size > 1)
-    int32_t n_eval   = 0; // number of eval calls
+    mutable int32_t n_p_eval = 0; // number of tokens in eval calls for the prompt (with batch size > 1)
+    mutable int32_t n_eval   = 0; // number of eval calls
 
     // host buffer for the model output (logits and embeddings)
     ggml_backend_buffer_t buf_output = nullptr;
@@ -6251,6 +6248,7 @@ static void llm_load_vocab(
 
     const uint32_t n_vocab = gguf_get_arr_n(ctx, token_idx);
 
+    vocab.n_vocab = n_vocab;
     vocab.id_to_token.resize(n_vocab);
 
     for (uint32_t i = 0; i < n_vocab; i++) {
@@ -17892,7 +17890,6 @@ struct llama_model_params llama_model_default_params() {
 
 struct llama_context_params llama_context_default_params() {
     struct llama_context_params result = {
-        /*.seed                        =*/ LLAMA_DEFAULT_SEED,
         /*.n_ctx                       =*/ 512,
         /*.n_batch                     =*/ 2048,
         /*.n_ubatch                    =*/ 512,
@@ -17925,6 +17922,14 @@ struct llama_context_params llama_context_default_params() {
     return result;
 }
 
+struct llama_sampler_chain_params llama_sampler_chain_default_params() {
+    struct llama_sampler_chain_params result = {
+        /*.no_perf                     =*/ true,
+    };
+
+    return result;
+}
+
 struct llama_model_quantize_params llama_model_quantize_default_params() {
     struct llama_model_quantize_params result = {
         /*.nthread                     =*/ 0,
@@ -18178,10 +18183,6 @@ struct llama_context * llama_new_context_with_model(
         cparams.causal_attn = params.attention_type == LLAMA_ATTENTION_TYPE_CAUSAL;
     }
 
-    if (params.seed == LLAMA_DEFAULT_SEED) {
-        params.seed = time(NULL);
-    }
-
     LLAMA_LOG_INFO("%s: n_ctx      = %u\n",     __func__, cparams.n_ctx);
     LLAMA_LOG_INFO("%s: n_batch    = %u\n",     __func__, cparams.n_batch);
     LLAMA_LOG_INFO("%s: n_ubatch   = %u\n",     __func__, cparams.n_ubatch);
@@ -18192,10 +18193,10 @@ struct llama_context * llama_new_context_with_model(
     ctx->abort_callback      = params.abort_callback;
     ctx->abort_callback_data = params.abort_callback_data;
 
-    ctx->sampling.rng = std::mt19937(params.seed);
-    ctx->logits_all   = params.logits_all;
+    ctx->logits_all = params.logits_all;
+
     // build worst-case graph for encoder if a model contains encoder
-    ctx->is_encoding  = llama_model_has_encoder(model);
+    ctx->is_encoding = llama_model_has_encoder(model);
 
     uint32_t kv_size = cparams.n_ctx;
     ggml_type type_k = params.type_k;
@@ -18473,14 +18474,6 @@ void llama_free(struct llama_context * ctx) {
     delete ctx;
 }
 
-const struct llama_model * llama_get_model(const struct llama_context * ctx) {
-    return &ctx->model;
-}
-
-const struct llama_vocab * llama_get_vocab(const struct llama_context * ctx) {
-    return &ctx->model.vocab;
-}
-
 uint32_t llama_n_ctx(const struct llama_context * ctx) {
     return ctx->cparams.n_ctx;
 }
@@ -18501,6 +18494,30 @@ enum llama_vocab_type llama_vocab_type(const struct llama_model * model) {
     return model->vocab.type;
 }
 
+int32_t llama_n_vocab(const struct llama_model * model) {
+    return model->hparams.n_vocab;
+}
+
+int32_t llama_n_ctx_train(const struct llama_model * model) {
+    return model->hparams.n_ctx_train;
+}
+
+int32_t llama_n_embd(const struct llama_model * model) {
+    return model->hparams.n_embd;
+}
+
+int32_t llama_n_layer(const struct llama_model * model) {
+    return model->hparams.n_layer;
+}
+
+const struct llama_model * llama_get_model(const struct llama_context * ctx) {
+    return &ctx->model;
+}
+
+enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx) {
+    return ctx->cparams.pooling_type;
+}
+
 enum llama_rope_type llama_rope_type(const struct llama_model * model) {
     switch (model->arch) {
         // these models do not use RoPE
@@ -18564,26 +18581,6 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
     return LLAMA_ROPE_TYPE_NONE;
 }
 
-enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx) {
-    return ctx->cparams.pooling_type;
-}
-
-int32_t llama_n_vocab(const struct llama_model * model) {
-    return model->hparams.n_vocab;
-}
-
-int32_t llama_n_ctx_train(const struct llama_model * model) {
-    return model->hparams.n_ctx_train;
-}
-
-int32_t llama_n_embd(const struct llama_model * model) {
-    return model->hparams.n_embd;
-}
-
-int32_t llama_n_layer(const struct llama_model * model) {
-    return model->hparams.n_layer;
-}
-
 float llama_rope_freq_scale_train(const struct llama_model * model) {
     return model->hparams.rope_freq_scale_train;
 }
@@ -19000,14 +18997,14 @@ struct llama_data_write {
         // TODO: add more model-specific info which should prevent loading the session file if not identical
     }
 
-    void write_rng(const std::mt19937 & rng) {
-        std::ostringstream rng_ss;
-        rng_ss << rng;
+    //void write_rng(const std::mt19937 & rng) {
+    //    std::ostringstream rng_ss;
+    //    rng_ss << rng;
 
-        const std::string & rng_str = rng_ss.str();
+    //    const std::string & rng_str = rng_ss.str();
 
-        write_string(rng_str);
-    }
+    //    write_string(rng_str);
+    //}
 
     void write_output_ids(struct llama_context * ctx) {
         llama_output_reorder(ctx);
@@ -19227,17 +19224,17 @@ struct llama_data_read {
         // TODO: add more info which needs to be identical but which is not verified otherwise
     }
 
-    void read_rng(std::mt19937 & rng) {
-        std::string rng_str;
-        read_string(rng_str);
+    //void read_rng(std::mt19937 & rng) {
+    //    std::string rng_str;
+    //    read_string(rng_str);
 
-        std::istringstream rng_ss(rng_str);
-        rng_ss >> rng;
+    //    std::istringstream rng_ss(rng_str);
+    //    rng_ss >> rng;
 
-        if (rng_ss.fail()) {
-            throw std::runtime_error("failed to load RNG state");
-        }
-    }
+    //    if (rng_ss.fail()) {
+    //        throw std::runtime_error("failed to load RNG state");
+    //    }
+    //}
 
     void read_output_ids(struct llama_context * ctx) {
         std::vector<int32_t> output_pos;
@@ -19667,8 +19664,6 @@ static size_t llama_state_get_data_internal(struct llama_context * ctx, llama_da
 
     data_ctx.write_model_info(ctx);
 
-    data_ctx.write_rng(ctx->sampling.rng);
-
     // copy outputs
     data_ctx.write_output_ids(ctx);
     data_ctx.write_logits(ctx);
@@ -19706,9 +19701,6 @@ static size_t llama_state_set_data_internal(struct llama_context * ctx, llama_da
 
     data_ctx.read_model_info(ctx);
 
-    // set rng
-    data_ctx.read_rng(ctx->sampling.rng);
-
     // set outputs
     data_ctx.read_output_ids(ctx);
     data_ctx.read_logits(ctx);
@@ -20111,8 +20103,9 @@ float * llama_get_logits_ith(struct llama_context * ctx, int32_t i) {
         LLAMA_LOG_ERROR("%s: invalid logits id %d, reason: %s\n", __func__, i, err.what());
 #ifndef NDEBUG
         GGML_ABORT("fatal error");
-#endif
+#else
         return nullptr;
+#endif
     }
 }
 
@@ -20160,8 +20153,9 @@ float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i) {
         LLAMA_LOG_ERROR("%s: invalid embeddings id %d, reason: %s\n", __func__, i, err.what());
 #ifndef NDEBUG
         GGML_ABORT("fatal error");
-#endif
+#else
         return nullptr;
+#endif
     }
 }
 
@@ -20595,124 +20589,18 @@ int32_t llama_chat_apply_template(
 }
 
 //
-// grammar
+// sampling
 //
 
-struct llama_grammar * llama_grammar_init(
-        const llama_grammar_element ** rules,
-        size_t    n_rules,
-        size_t    start_rule_index) {
-    return llama_grammar_init_impl(rules, n_rules, start_rule_index);
-}
-
-void llama_grammar_free(struct llama_grammar * grammar) {
-    llama_grammar_free_impl(grammar);
-}
-
-struct llama_grammar * llama_grammar_copy(const struct llama_grammar * grammar) {
-    return llama_grammar_copy_impl(grammar);
-}
-
-void llama_grammar_sample(
-      const struct llama_grammar * grammar,
-      const struct llama_context * ctx,
-          llama_token_data_array * candidates) {
-    llama_grammar_sample_impl(grammar, &ctx->model.vocab, &ctx->sampling, candidates);
-}
-
-void llama_sample_grammar(
-            struct llama_context * ctx,
-          llama_token_data_array * candidates,
-      const struct llama_grammar * grammar) {
-    llama_grammar_sample(grammar, ctx, candidates);
-}
-
-void llama_grammar_accept_token(
-            struct llama_grammar * grammar,
-            struct llama_context * ctx,
-                     llama_token   token) {
-    llama_grammar_accept_token_impl(grammar, &ctx->model.vocab, &ctx->sampling, token);
+// TODO: remove indirection when vocab becomes accesible in llama-sampling.cpp
+struct llama_sampler * llama_sampler_init_grammar(const struct llama_model * model, const char * grammar_str, const char * grammar_root) {
+    return llama_sampler_init_grammar_impl(model->vocab, grammar_str, grammar_root);
 }
 
 //
-// sampling
+// model split
 //
 
-void llama_set_rng_seed(struct llama_context * ctx, uint32_t seed) {
-    llama_set_rng_seed_impl(&ctx->sampling, seed);
-}
-
-void llama_sample_softmax(struct llama_context * ctx, llama_token_data_array * candidates) {
-    llama_sample_softmax_impl(ctx ? &ctx->sampling : nullptr, candidates);
-}
-
-void llama_sample_top_k(struct llama_context * ctx, llama_token_data_array * candidates, int32_t k, size_t min_keep) {
-    llama_sample_top_k_impl(ctx ? &ctx->sampling : nullptr, candidates, k, min_keep);
-}
-
-void llama_sample_top_p(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep) {
-    llama_sample_top_p_impl(ctx ? &ctx->sampling : nullptr, candidates, p, min_keep);
-}
-
-void llama_sample_min_p(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep) {
-    llama_sample_min_p_impl(ctx ? &ctx->sampling : nullptr, candidates, p, min_keep);
-}
-
-void llama_sample_tail_free(struct llama_context * ctx, llama_token_data_array * candidates, float z, size_t min_keep) {
-    llama_sample_tail_free_impl(ctx ? &ctx->sampling : nullptr, candidates, z, min_keep);
-}
-
-void llama_sample_typical(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep) {
-    llama_sample_typical_impl(ctx ? &ctx->sampling : nullptr, candidates, p, min_keep);
-}
-
-void llama_sample_entropy(struct llama_context * ctx, llama_token_data_array * candidates_p, float min_temp, float max_temp, float exponent_val) {
-    llama_sample_entropy_impl(ctx ? &ctx->sampling : nullptr, candidates_p, min_temp, max_temp, exponent_val);
-}
-
-void llama_sample_temp(struct llama_context * ctx, llama_token_data_array * candidates_p, float temp) {
-    llama_sample_temp_impl(ctx ? &ctx->sampling : nullptr, candidates_p, temp);
-}
-
-void llama_sample_repetition_penalties(
-            struct llama_context * ctx,
-          llama_token_data_array * candidates,
-               const llama_token * last_tokens,
-                          size_t   penalty_last_n,
-                           float   penalty_repeat,
-                           float   penalty_freq,
-                           float   penalty_present) {
-    llama_sample_repetition_penalties_impl(ctx ? &ctx->sampling : nullptr, candidates, last_tokens, penalty_last_n, penalty_repeat, penalty_freq, penalty_present);
-}
-
-void llama_sample_apply_guidance(
-          struct llama_context * ctx,
-                         float * logits,
-                         float * logits_guidance,
-                         float   scale) {
-    llama_sample_apply_guidance_impl(&ctx->sampling, logits, logits_guidance, scale);
-}
-
-llama_token llama_sample_token_mirostat(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, int32_t m, float * mu) {
-    return llama_sample_token_mirostat_impl(&ctx->sampling, candidates, tau, eta, m, mu);
-}
-
-llama_token llama_sample_token_mirostat_v2(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, float * mu) {
-    return llama_sample_token_mirostat_v2_impl(ctx ? &ctx->sampling : nullptr, candidates, tau, eta, mu);
-}
-
-llama_token llama_sample_token_greedy(struct llama_context * ctx, llama_token_data_array * candidates) {
-    return llama_sample_token_greedy_impl(ctx ? &ctx->sampling : nullptr, candidates);
-}
-
-llama_token llama_sample_token_with_rng(struct llama_context * ctx, llama_token_data_array * candidates, std::mt19937 & rng) {
-    return llama_sample_token_with_rng_impl(&ctx->sampling, candidates, rng);
-}
-
-llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_array * candidates) {
-    return llama_sample_token_with_rng_impl(&ctx->sampling, candidates, ctx->sampling.rng);
-}
-
 int llama_split_path(char * split_path, size_t maxlen, const char * path_prefix, int split_no, int split_count) {
     static const char * const SPLIT_PATH_FORMAT = "%s-%05d-of-%05d.gguf";
     if (snprintf(split_path, maxlen, SPLIT_PATH_FORMAT, path_prefix, split_no + 1, split_count)) {
@@ -20737,45 +20625,6 @@ int llama_split_prefix(char * dest, size_t maxlen, const char * split_path, int
     return 0;
 }
 
-struct llama_timings llama_get_timings(struct llama_context * ctx) {
-    struct llama_timings result = {
-        /*.t_start_ms  =*/ 1e-3 * ctx->t_start_us,
-        /*.t_end_ms    =*/ 1.00 * ggml_time_ms(),
-        /*.t_load_ms   =*/ 1e-3 * ctx->t_load_us,
-        /*.t_sample_ms =*/ 1e-3 * ctx->sampling.t_sample_us,
-        /*.t_p_eval_ms =*/ 1e-3 * ctx->t_p_eval_us,
-        /*.t_eval_ms   =*/ 1e-3 * ctx->t_eval_us,
-
-        /*.n_sample =*/ std::max(1, ctx->sampling.n_sample),
-        /*.n_p_eval =*/ std::max(0, ctx->n_p_eval),
-        /*.n_eval   =*/ std::max(1, ctx->n_eval),
-    };
-
-    return result;
-}
-
-void llama_print_timings(struct llama_context * ctx) {
-    const llama_timings timings = llama_get_timings(ctx);
-
-    LLAMA_LOG_INFO("\n");
-    LLAMA_LOG_INFO("%s:        load time = %10.2f ms\n", __func__, timings.t_load_ms);
-    LLAMA_LOG_INFO("%s:      sample time = %10.2f ms / %5d runs   (%8.2f ms per token, %8.2f tokens per second)\n",
-            __func__, timings.t_sample_ms, timings.n_sample, timings.t_sample_ms / timings.n_sample, 1e3 / timings.t_sample_ms * timings.n_sample);
-    LLAMA_LOG_INFO("%s: prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n",
-            __func__, timings.t_p_eval_ms, timings.n_p_eval, timings.t_p_eval_ms / timings.n_p_eval, 1e3 / timings.t_p_eval_ms * timings.n_p_eval);
-    LLAMA_LOG_INFO("%s:        eval time = %10.2f ms / %5d runs   (%8.2f ms per token, %8.2f tokens per second)\n",
-            __func__, timings.t_eval_ms, timings.n_eval, timings.t_eval_ms / timings.n_eval, 1e3 / timings.t_eval_ms * timings.n_eval);
-    LLAMA_LOG_INFO("%s:       total time = %10.2f ms / %5d tokens\n", __func__, (timings.t_end_ms - timings.t_start_ms), (timings.n_p_eval + timings.n_eval));
-}
-
-void llama_reset_timings(struct llama_context * ctx) {
-    ctx->t_start_us  = ggml_time_us();
-    ctx->t_eval_us   = ctx->n_eval   = 0;
-    ctx->t_p_eval_us = ctx->n_p_eval = 0;
-
-    ctx->sampling.reset_timings();
-}
-
 const char * llama_print_system_info(void) {
     static std::string s;
 
@@ -20804,7 +20653,68 @@ const char * llama_print_system_info(void) {
     return s.c_str();
 }
 
-void llama_dump_timing_info_yaml(FILE * stream, const llama_context * ctx) {
+void llama_perf_print(const void * ctx, enum llama_perf_type type) {
+    switch (type) {
+        case LLAMA_PERF_TYPE_CONTEXT:
+            {
+                const auto * p = (const struct llama_context *) ctx;
+
+                const double t_start_ms   = 1e-3 * p->t_start_us;
+                const double t_end_ms     = 1.00 * ggml_time_ms();
+                const double t_load_ms    = 1e-3 * p->t_load_us;
+                const double t_p_eval_ms  = 1e-3 * p->t_p_eval_us;
+                const double t_eval_ms    = 1e-3 * p->t_eval_us;
+
+                const int32_t n_p_eval  = std::max(0, p->n_p_eval);
+                const int32_t n_eval    = std::max(1, p->n_eval);
+
+                LLAMA_LOG_INFO("%s:        load time = %10.2f ms\n", __func__, t_load_ms);
+                LLAMA_LOG_INFO("%s: prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n",
+                        __func__, t_p_eval_ms, n_p_eval, t_p_eval_ms / n_p_eval, 1e3 / t_p_eval_ms * n_p_eval);
+                LLAMA_LOG_INFO("%s:        eval time = %10.2f ms / %5d runs   (%8.2f ms per token, %8.2f tokens per second)\n",
+                        __func__, t_eval_ms, n_eval, t_eval_ms / n_eval, 1e3 / t_eval_ms * n_eval);
+                LLAMA_LOG_INFO("%s:       total time = %10.2f ms / %5d tokens\n", __func__, (t_end_ms - t_start_ms), (n_p_eval + n_eval));
+            } break;
+        case LLAMA_PERF_TYPE_SAMPLER_CHAIN:
+            {
+                const auto * smpl = (const struct llama_sampler *) ctx;
+                const auto * p = (const struct llama_sampler_chain *) smpl->ctx;
+
+                const double t_sampler_ms = 1e-3 * p->t_sample_us;
+
+                const int32_t n_sampler = std::max(0, p->n_sample);
+
+                LLAMA_LOG_INFO("%s:    sampling time = %10.2f ms / %5d runs   (%8.2f ms per token, %8.2f tokens per second)\n",
+                        __func__, t_sampler_ms, n_sampler, t_sampler_ms / n_sampler, 1e3 / t_sampler_ms * n_sampler);
+            } break;
+        default:
+            GGML_ABORT("invalid perf type");
+    }
+}
+
+void llama_perf_reset(void * ctx, enum llama_perf_type type) {
+    switch (type) {
+        case LLAMA_PERF_TYPE_CONTEXT:
+            {
+                auto * p = (struct llama_context *) ctx;
+
+                p->t_start_us  = ggml_time_us();
+                p->t_eval_us   = p->n_eval = 0;
+                p->t_p_eval_us = p->n_p_eval = 0;
+            } break;
+        case LLAMA_PERF_TYPE_SAMPLER_CHAIN:
+            {
+                auto * smpl = (struct llama_sampler *) ctx;
+                auto * p = (struct llama_sampler_chain *) smpl->ctx;
+
+                p->t_sample_us = p->n_sample = 0;
+            } break;
+        default:
+            GGML_ABORT("invalid perf type");
+    }
+}
+
+void llama_perf_dump_yaml(FILE * stream, const llama_context * ctx) {
     fprintf(stream, "\n");
     fprintf(stream, "###########\n");
     fprintf(stream, "# Timings #\n");
@@ -20815,21 +20725,15 @@ void llama_dump_timing_info_yaml(FILE * stream, const llama_context * ctx) {
             1.0e-3 * ctx->t_eval_us / ctx->n_eval);
     fprintf(stream, "mst_p_eval: %.2f  # ms / token during prompt processing\n",
             1.0e-3 * ctx->t_p_eval_us / ctx->n_p_eval);
-    fprintf(stream, "mst_sample: %.2f  # ms / token during sampling\n",
-            1.0e-3 * ctx->sampling.t_sample_us / ctx->sampling.n_sample);
     fprintf(stream, "n_eval: %d  # number of tokens generated (excluding the first one)\n", ctx->n_eval);
     fprintf(stream, "n_p_eval: %d  # number of tokens processed in batches at the beginning\n", ctx->n_p_eval);
-    fprintf(stream, "n_sample: %d  # number of sampled tokens\n", ctx->sampling.n_sample);
     fprintf(stream, "t_eval_us: %" PRId64 "  # total microseconds spent generating tokens\n", ctx->t_eval_us);
     fprintf(stream, "t_load_us: %" PRId64 "  # total microseconds spent loading the model\n", ctx->t_load_us);
     fprintf(stream, "t_p_eval_us: %" PRId64 "  # total microseconds spent prompt processing\n", ctx->t_p_eval_us);
-    fprintf(stream, "t_sample_us: %" PRId64 "  # total microseconds spent sampling\n", ctx->sampling.t_sample_us);
     fprintf(stream, "ts_eval: %.2f  # tokens / second during generation\n",
             1.0e6 * ctx->n_eval / ctx->t_eval_us);
     fprintf(stream, "ts_p_eval: %.2f  # tokens / second during prompt processing\n",
             1.0e6 * ctx->n_p_eval / ctx->t_p_eval_us);
-    fprintf(stream, "ts_sample: %.2f  # tokens / second during sampling\n",
-            1.0e6 * ctx->sampling.n_sample / ctx->sampling.t_sample_us);
 }
 
 // For internal test use
index 9c4e7d18e37b23e5081e81b8c8a9a99a9fa0d81b..5cc0cdb04751ff776dc1a727bfc1edbd1c6edc70 100644 (file)
@@ -2,33 +2,18 @@
 #undef NDEBUG
 #endif
 
-#define LLAMA_API_INTERNAL
-
-#include "ggml.h"
-#include "llama.h"
-#include "grammar-parser.h"
-#include "json-schema-to-grammar.h"
 #include "unicode.h"
+#include "llama-grammar.h"
+#include "json-schema-to-grammar.h"
+
 #include <cassert>
 #include <string>
 #include <vector>
 
 using json = nlohmann::ordered_json;
 
-static llama_grammar* build_grammar(const std::string & grammar_str) {
-    auto parsed_grammar = grammar_parser::parse(grammar_str.c_str());
-
-    // Ensure we parsed correctly
-    assert(!parsed_grammar.rules.empty());
-
-    // Ensure we have a root node
-    assert(!(parsed_grammar.symbol_ids.find("root") == parsed_grammar.symbol_ids.end()));
-
-    std::vector<const llama_grammar_element*> grammar_rules(parsed_grammar.c_rules());
-    llama_grammar* grammar = llama_grammar_init(
-        grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
-
-    return grammar;
+static llama_grammar * build_grammar(const std::string & grammar_str) {
+    return llama_grammar_init_impl(nullptr, grammar_str.c_str(), "root");
 }
 
 static bool test_build_grammar_fails(const std::string & grammar_str) {
@@ -45,25 +30,23 @@ static bool test_build_grammar_fails(const std::string & grammar_str) {
 }
 
 static bool match_string(const std::string & input, llama_grammar * grammar) {
-    auto decoded = decode_utf8(input, {});
-
-    const auto & code_points = decoded.first;
+    const auto cpts = unicode_cpts_from_utf8(input);
 
     const llama_grammar_rules  & rules      = llama_grammar_get_rules (grammar);
-          llama_grammar_stacks & cur_stacks = llama_grammar_get_stacks(grammar);
+          llama_grammar_stacks & stacks_cur = llama_grammar_get_stacks(grammar);
 
-    for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
-        const llama_grammar_stacks prev_stacks = llama_grammar_get_stacks(grammar); // copy
+    for (const auto & cpt : cpts) {
+        const llama_grammar_stacks stacks_prev = llama_grammar_get_stacks(grammar); // copy
 
-        llama_grammar_accept(rules, prev_stacks, *it, cur_stacks);
+        llama_grammar_accept(rules, stacks_prev, cpt, stacks_cur);
 
-        if (cur_stacks.empty()) {
+        if (stacks_cur.empty()) {
             // no stacks means that the grammar failed to match at this point
             return false;
         }
     }
 
-    for (const auto & stack : cur_stacks) {
+    for (const auto & stack : stacks_cur) {
         if (stack.empty()) {
             // An empty stack means that the grammar has been completed
             return true;
@@ -77,12 +60,12 @@ static void test(const std::string & test_desc, const std::string & grammar_str,
     fprintf(stderr, "âš« Testing %s\n%s\n", test_desc.c_str(), grammar_str.c_str());
     fflush(stderr);
 
-    auto grammar = build_grammar(grammar_str);
+    auto grammar = build_grammar(grammar_str);
 
     // Save the original grammar stacks so that we can reset after every new string we want to test
-    const llama_grammar_stacks original_stacks = llama_grammar_get_stacks(grammar);
+    const llama_grammar_stacks stacks_org = llama_grammar_get_stacks(grammar);
 
-    llama_grammar_stacks & cur_stacks = llama_grammar_get_stacks(grammar);
+    llama_grammar_stacks & stacks_cur = llama_grammar_get_stacks(grammar);
 
     fprintf(stderr, "  ðŸ”µ Valid strings:\n");
 
@@ -119,7 +102,7 @@ static void test(const std::string & test_desc, const std::string & grammar_str,
         assert(matched);
 
         // Reset the grammar stacks
-        cur_stacks = original_stacks;
+        stacks_cur = stacks_org;
     }
 
     fprintf(stderr, "  ðŸŸ  Invalid strings:\n");
@@ -139,11 +122,11 @@ static void test(const std::string & test_desc, const std::string & grammar_str,
         assert(!matched);
 
         // Reset the grammar stacks
-        cur_stacks = original_stacks;
+        stacks_cur = stacks_org;
     }
 
     // Clean up allocated memory
-    llama_grammar_free(grammar);
+    llama_grammar_free_impl(grammar);
 }
 static void test_grammar(const std::string & test_desc, const std::string & grammar_str, const std::vector<std::string> & passing_strings, const std::vector<std::string> & failing_strings) {
     test(test_desc + ". Grammar: " + grammar_str, grammar_str, passing_strings, failing_strings);
@@ -683,7 +666,8 @@ static void test_failure_missing_root() {
         term ::= number
         number ::= [0-9]+)""";
 
-    grammar_parser::parse_state parsed_grammar = grammar_parser::parse(grammar_str.c_str());
+    llama_grammar_parser parsed_grammar;
+    parsed_grammar.parse(grammar_str.c_str());
 
     // Ensure we parsed correctly
     assert(!parsed_grammar.rules.empty());
@@ -705,7 +689,8 @@ static void test_failure_missing_reference() {
 
     fprintf(stderr, "    Expected error:  ");
 
-    grammar_parser::parse_state parsed_grammar = grammar_parser::parse(grammar_str.c_str());
+    llama_grammar_parser parsed_grammar;
+    parsed_grammar.parse(grammar_str.c_str());
 
     // Ensure we did NOT parsed correctly
     assert(parsed_grammar.rules.empty());
index 5df5abb25394c8418eb8813ca4dfa5c3797f000b..259172d999c789e3deb0f25e907978a37b61417a 100644 (file)
@@ -3,7 +3,7 @@
 #endif
 
 #include "llama.h"
-#include "grammar-parser.h"
+#include "llama-grammar.h"
 
 #include <cassert>
 
@@ -22,7 +22,8 @@ static const char * type_str(llama_gretype type) {
 
 static void verify_parsing(const char *grammar_bytes, const std::vector<std::pair<std::string, uint32_t>> expected, const std::vector<llama_grammar_element> &expected_rules) {
     uint32_t index = 0;
-    grammar_parser::parse_state parsed_grammar = grammar_parser::parse(grammar_bytes);
+    llama_grammar_parser parsed_grammar;
+    parsed_grammar.parse(grammar_bytes);
 
     std::map<uint32_t, std::string> symbol_names;
     for (auto it = parsed_grammar.symbol_ids.begin(); it != parsed_grammar.symbol_ids.end(); ++it) {
@@ -129,9 +130,10 @@ static void verify_parsing(const char *grammar_bytes, const std::vector<std::pai
     }
 }
 
-static void verify_failure(const char *grammar_bytes) {
+static void verify_failure(const char * grammar_bytes) {
     fprintf(stderr, "Testing expected failure:%s\n", grammar_bytes);
-    auto result = grammar_parser::parse(grammar_bytes);
+    llama_grammar_parser result;
+    result.parse(grammar_bytes);
     assert(result.rules.empty() && "should have failed");
 }
 
index 65486ac5c2d1bdf88ff55da1c2ea469b93c95dbb..3a89598a82edb04777c4a810d9fe962a66727b05 100755 (executable)
@@ -2,14 +2,15 @@
 #undef NDEBUG
 #endif
 
+#include "json-schema-to-grammar.h"
+
+#include "llama-grammar.h"
+
 #include <cassert>
 #include <fstream>
 #include <sstream>
 #include <regex>
 
-#include "json-schema-to-grammar.h"
-#include "grammar-parser.h"
-
 static std::string trim(const std::string & source) {
     std::string s(source);
     s.erase(0,s.find_first_not_of(" \n\r\t"));
@@ -40,7 +41,8 @@ struct TestCase {
     }
     void verify_expectation_parseable() const {
         try {
-            auto state = grammar_parser::parse(expected_grammar.c_str());
+            llama_grammar_parser state;
+            state.parse(expected_grammar.c_str());
             if (state.symbol_ids.find("root") == state.symbol_ids.end()) {
                 throw std::runtime_error("Grammar failed to parse:\n" + expected_grammar);
             }
index 1f3a267b39f9bab8cc61979d7a86549d6837e572..6f1374ca8ed58f1d1ecdeefbe1db42a32ce22fa7 100644 (file)
@@ -2,16 +2,15 @@
 #undef NDEBUG
 #endif
 
-#define LLAMA_API_INTERNAL
 #include "llama.h"
-#include "grammar-parser.h"
+#include "llama-grammar.h"
 
 #include <cassert>
 #include <stdexcept>
 
 int main()
 {
-    grammar_parser::parse_state parsed_grammar;
+    llama_grammar_parser parsed_grammar;
 
     std::vector<std::pair<std::string, uint32_t>> expected = {
         {"expr", 2},
@@ -117,7 +116,7 @@ int main()
     llama_grammar * grammar = NULL;
     std::vector<const llama_grammar_element *> grammar_rules(parsed_grammar.c_rules());
 
-    grammar = llama_grammar_init(grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
+    grammar = llama_grammar_init_impl(nullptr, grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
     if (grammar == nullptr)
     {
         throw std::runtime_error("Failed to initialize llama_grammar");
@@ -174,13 +173,13 @@ int main()
         }};
 
     auto index = 0;
-    for (auto stack : llama_grammar_get_stacks(grammar))
+    for (const llama_grammar_stack & stack : llama_grammar_get_stacks(grammar))
     {
         // compare stack to expected_stack
         for (uint32_t i = 0; i < stack.size(); i++)
         {
-            auto element = stack[i];
-            auto expected_element = expected_stacks[index][i];
+            const llama_grammar_element * element = stack[i];
+            const llama_grammar_element & expected_element = expected_stacks[index][i];
 
             // pretty print error message before asserting
             if (expected_element.type != element->type || expected_element.value != element->value)
@@ -403,6 +402,8 @@ int main()
         delete[] candidate.code_points;
         candidate.code_points = nullptr;
     }
-    llama_grammar_free(grammar);
+
+    llama_grammar_free_impl(grammar);
+
     return 0;
 }
index 6c2a5db9accf2c0b14c146fc866e1c2c52d43146..cc4882d37579a6327c07256dbca362304fd9795d 100644 (file)
@@ -1,5 +1,6 @@
 #include "ggml.h"
 #include "llama.h"
+#include "llama-sampling.h"
 
 #ifdef NDEBUG
 #undef NDEBUG
 #include <string>
 #include <vector>
 
-static void dump(const llama_token_data_array * candidates) {
-    for (size_t i = 0; i < candidates->size; i++) {
-        printf("%d: %f (%f)\n", candidates->data[i].id, candidates->data[i].p, candidates->data[i].logit);
+static void dump(const llama_token_data_array * cur_p) {
+    for (size_t i = 0; i < cur_p->size; i++) {
+        printf("%d: %f (%f)\n", cur_p->data[i].id, cur_p->data[i].p, cur_p->data[i].logit);
     }
 }
 
-#define DUMP(__candidates) do { printf("%s:%d (%s)\n", __FILE__, __LINE__, __func__); dump((__candidates)); printf("-\n"); } while(0)
+#define DUMP(__cur_p) do { printf("%s:%d (%s)\n", __FILE__, __LINE__, __func__); dump((__cur_p)); printf("-\n"); } while(0)
+
+#define APPLY(__cnstr, __cur_p) do { \
+    auto * cnstr = (__cnstr); \
+    llama_sampler_apply(cnstr, (__cur_p)); \
+    llama_sampler_free(cnstr); \
+} while(0)
 
 static void test_top_k(const std::vector<float> & probs, const std::vector<float> & expected_probs, int k) {
     const size_t n_vocab = probs.size();
-    std::vector<llama_token_data> candidates;
-    candidates.reserve(n_vocab);
+
+    std::vector<llama_token_data> cur;
+    cur.reserve(n_vocab);
     for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
         const float logit = logf(probs[token_id]);
-        candidates.emplace_back(llama_token_data{token_id, logit, 0.0f});
+        cur.emplace_back(llama_token_data{token_id, logit, 0.0f});
     }
 
-    llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
-    llama_sample_softmax(nullptr, &candidates_p);
-    DUMP(&candidates_p);
-    llama_sample_top_k(nullptr, &candidates_p, k, 1);
-    DUMP(&candidates_p);
+    llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false };
+    APPLY(llama_sampler_init_softmax(), &cur_p);
+    DUMP(&cur_p);
+    APPLY(llama_sampler_init_top_k(k), &cur_p);
+    DUMP(&cur_p);
 
-    GGML_ASSERT(candidates_p.size == expected_probs.size());
-    for (size_t i = 0; i < candidates_p.size; i++) {
-        GGML_ASSERT(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-5);
+    GGML_ASSERT(cur_p.size == expected_probs.size());
+    for (size_t i = 0; i < cur_p.size; i++) {
+        GGML_ASSERT(fabs(cur_p.data[i].p - expected_probs[i]) < 1e-5);
     }
 }
 
 static void test_top_p(const std::vector<float> & probs, const std::vector<float> & expected_probs, float p) {
     const size_t n_vocab = probs.size();
-    std::vector<llama_token_data> candidates;
-    candidates.reserve(n_vocab);
+
+    std::vector<llama_token_data> cur;
+    cur.reserve(n_vocab);
     for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
         const float logit = logf(probs[token_id]);
-        candidates.emplace_back(llama_token_data{token_id, logit, 0.0f});
+        cur.emplace_back(llama_token_data{token_id, logit, 0.0f});
     }
 
-    llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
-    llama_sample_softmax(nullptr, &candidates_p);
-    DUMP(&candidates_p);
-    llama_sample_top_p(nullptr, &candidates_p, p, 1);
-    DUMP(&candidates_p);
+    llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false };
+    APPLY(llama_sampler_init_softmax(), &cur_p);
+    DUMP(&cur_p);
+    APPLY(llama_sampler_init_top_p(p, 1), &cur_p);
+    DUMP(&cur_p);
 
-    GGML_ASSERT(candidates_p.size == expected_probs.size());
-    for (size_t i = 0; i < candidates_p.size; i++) {
-        GGML_ASSERT(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-3);
+    GGML_ASSERT(cur_p.size == expected_probs.size());
+    for (size_t i = 0; i < cur_p.size; i++) {
+        GGML_ASSERT(fabs(cur_p.data[i].p - expected_probs[i]) < 1e-3);
     }
 }
 
 static void test_tfs(const std::vector<float> & probs, const std::vector<float> & expected_probs, float z) {
     const size_t n_vocab = probs.size();
-    std::vector<llama_token_data> candidates;
-    candidates.reserve(n_vocab);
+
+    std::vector<llama_token_data> cur;
+    cur.reserve(n_vocab);
     for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
         const float logit = logf(probs[token_id]);
-        candidates.emplace_back(llama_token_data{token_id, logit, 0.0f});
+        cur.emplace_back(llama_token_data{token_id, logit, 0.0f});
     }
 
-    llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
-    DUMP(&candidates_p);
-    llama_sample_tail_free(nullptr, &candidates_p, z, 1);
-    DUMP(&candidates_p);
+    llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false };
+    DUMP(&cur_p);
+    APPLY(llama_sampler_init_tail_free(z, 1), &cur_p);
+    DUMP(&cur_p);
 
-    GGML_ASSERT(candidates_p.size == expected_probs.size());
-    for (size_t i = 0; i < candidates_p.size; i++) {
-        GGML_ASSERT(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-3);
+    GGML_ASSERT(cur_p.size == expected_probs.size());
+    for (size_t i = 0; i < cur_p.size; i++) {
+        GGML_ASSERT(fabs(cur_p.data[i].p - expected_probs[i]) < 1e-3);
     }
 }
 
 static void test_min_p(const std::vector<float> & probs, const std::vector<float> & expected_probs, float p) {
     const size_t n_vocab = probs.size();
-    std::vector<llama_token_data> candidates;
-    candidates.reserve(n_vocab);
+
+    std::vector<llama_token_data> cur;
+    cur.reserve(n_vocab);
     for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
         const float logit = logf(probs[token_id]);
-        candidates.emplace_back(llama_token_data{token_id, logit, 0.0f});
+        cur.emplace_back(llama_token_data{token_id, logit, 0.0f});
     }
 
-    llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
-    DUMP(&candidates_p);
-    llama_sample_min_p(nullptr, &candidates_p, p, 1);
-    DUMP(&candidates_p);
-    llama_sample_softmax(nullptr, &candidates_p);
+    llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false };
+    DUMP(&cur_p);
+    APPLY(llama_sampler_init_min_p(p, 1), &cur_p);
+    DUMP(&cur_p);
+    APPLY(llama_sampler_init_softmax(), &cur_p);
 
-    GGML_ASSERT(candidates_p.size == expected_probs.size());
-    for (size_t i = 0; i < candidates_p.size; i++) {
-        GGML_ASSERT(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-3);
+    GGML_ASSERT(cur_p.size == expected_probs.size());
+    for (size_t i = 0; i < cur_p.size; i++) {
+        GGML_ASSERT(fabs(cur_p.data[i].p - expected_probs[i]) < 1e-3);
     }
 }
 
 static void test_typical(const std::vector<float> & probs, const std::vector<float> & expected_probs, float p) {
     const size_t n_vocab = probs.size();
-    std::vector<llama_token_data> candidates;
-    candidates.reserve(n_vocab);
+
+    std::vector<llama_token_data> cur;
+    cur.reserve(n_vocab);
     for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
         const float logit = logf(probs[token_id]);
-        candidates.emplace_back(llama_token_data{token_id, logit, 0.0f});
+        cur.emplace_back(llama_token_data{token_id, logit, 0.0f});
     }
 
-    llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
-    DUMP(&candidates_p);
-    llama_sample_typical(nullptr, &candidates_p, p, 1);
-    DUMP(&candidates_p);
+    llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false };
+    DUMP(&cur_p);
+    APPLY(llama_sampler_init_typical(p, 1), &cur_p);
+    DUMP(&cur_p);
 
-    GGML_ASSERT(candidates_p.size == expected_probs.size());
-    for (size_t i = 0; i < candidates_p.size; i++) {
-        GGML_ASSERT(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-3);
+    GGML_ASSERT(cur_p.size == expected_probs.size());
+    for (size_t i = 0; i < cur_p.size; i++) {
+        GGML_ASSERT(fabs(cur_p.data[i].p - expected_probs[i]) < 1e-3);
     }
 }
 
-static void test_repetition_penalties(
+static void test_penalties(
     const std::vector<float> & probs, const std::vector<llama_token> & last_tokens,
     const std::vector<float> & expected_probs, float repeat_penalty, float alpha_frequency, float alpha_presence
 ) {
     GGML_ASSERT(probs.size() == expected_probs.size());
 
     const size_t n_vocab = probs.size();
-    std::vector<llama_token_data> candidates;
-    candidates.reserve(n_vocab);
+
+    std::vector<llama_token_data> cur;
+    cur.reserve(n_vocab);
     for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
         const float logit = logf(probs[token_id]);
-        candidates.emplace_back(llama_token_data{token_id, logit, 0.0f});
+        cur.emplace_back(llama_token_data{token_id, logit, 0.0f});
+    }
+
+    llama_token_cnt token_count;
+    for (size_t i = 0; i < last_tokens.size(); i++) {
+        token_count[last_tokens[i]]++;
     }
 
-    llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
-    llama_sample_softmax(nullptr, &candidates_p);
-    DUMP(&candidates_p);
-    llama_sample_repetition_penalties(nullptr, &candidates_p, (const llama_token *) last_tokens.data(), last_tokens.size(), repeat_penalty, alpha_frequency, alpha_presence);
-    llama_sample_softmax(nullptr, &candidates_p);
-    DUMP(&candidates_p);
+    llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false };
+    APPLY(llama_sampler_init_softmax(), &cur_p);
+    DUMP(&cur_p);
+    llama_sampler_penalties_impl(&cur_p, token_count, repeat_penalty, alpha_frequency, alpha_presence); // TODO: avoid
+    APPLY(llama_sampler_init_softmax(), &cur_p);
+    DUMP(&cur_p);
 
-    GGML_ASSERT(candidates_p.size == expected_probs.size());
-    for (size_t i = 0; i < candidates_p.size; i++) {
-        GGML_ASSERT(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-3);
+    GGML_ASSERT(cur_p.size == expected_probs.size());
+    for (size_t i = 0; i < cur_p.size; i++) {
+        GGML_ASSERT(fabs(cur_p.data[i].p - expected_probs[i]) < 1e-3);
     }
 }
 
-static void test_sampler_queue(
-    const size_t n_vocab, const std::string samplers_sequence, const int top_k, const float top_p, const float min_p
+static void test_sampler_queue(const size_t n_vocab, const std::string & samplers_sequence, const int top_k, const float top_p, const float min_p
 ) {
-    std::vector<llama_token_data> candidates;
-    candidates.reserve(n_vocab);
+    std::vector<llama_token_data> cur;
+    cur.reserve(n_vocab);
     for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
         const float logit = logf(token_id);
-        candidates.emplace_back(llama_token_data{token_id, logit, 0.0f});
+        cur.emplace_back(llama_token_data{token_id, logit, 0.0f});
     }
 
-    llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
+    llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false };
 
           llama_token min_token_id = 0;
     const llama_token max_token_id = n_vocab-1;
 
     for (auto s : samplers_sequence) {
         switch (s){
-            case 'k': llama_sample_top_k    (nullptr, &candidates_p, top_k, 1); break;
+            case 'k': APPLY(llama_sampler_init_top_k(top_k), &cur_p); break;
             case 'f': GGML_ABORT("tail_free test not implemented");
             case 'y': GGML_ABORT("typical test not implemented");
-            case 'p': llama_sample_top_p    (nullptr, &candidates_p, top_p, 1); break;
-            case 'm': llama_sample_min_p    (nullptr, &candidates_p, min_p, 1); break;
+            case 'p': APPLY(llama_sampler_init_top_p(top_p, 1), &cur_p); break;
+            case 'm': APPLY(llama_sampler_init_min_p(min_p, 1), &cur_p); break;
             case 't': GGML_ABORT("temperature test not implemented");
             default : GGML_ABORT("Unknown sampler");
         }
 
-        llama_sample_softmax(nullptr, &candidates_p); // make sure tokens are sorted for tests
+        APPLY(llama_sampler_init_softmax(), &cur_p); // make sure tokens are sorted for tests
 
-        const int size = candidates_p.size;
+        const int size = cur_p.size;
 
         if (s == 'k') {
             const int expected_size = std::min(size, top_k);
             min_token_id = std::max(min_token_id, (llama_token)(n_vocab - top_k));
 
             GGML_ASSERT(size == expected_size);
-            GGML_ASSERT(candidates_p.data[0].id == max_token_id);
-            GGML_ASSERT(candidates_p.data[expected_size-1].id == min_token_id);
+            GGML_ASSERT(cur_p.data[0].id == max_token_id);
+            GGML_ASSERT(cur_p.data[expected_size-1].id == min_token_id);
         } else if (s == 'p') {
             const int softmax_divisor = n_vocab * (n_vocab-1) / 2 - min_token_id * (min_token_id-1) / 2;
             const int softmax_numerator_target = ceilf(top_p * softmax_divisor);
@@ -206,8 +223,8 @@ static void test_sampler_queue(
             }
 
             GGML_ASSERT(size == expected_size);
-            GGML_ASSERT(candidates_p.data[0].id == max_token_id);
-            GGML_ASSERT(candidates_p.data[expected_size-1].id == min_token_id);
+            GGML_ASSERT(cur_p.data[0].id == max_token_id);
+            GGML_ASSERT(cur_p.data[expected_size-1].id == min_token_id);
         } else if (s == 'm') {
             int expected_size = ceilf((1.0f-min_p) * n_vocab);
             expected_size = std::max(expected_size, 1);
@@ -219,8 +236,8 @@ static void test_sampler_queue(
             min_token_id = std::min(min_token_id, (llama_token)(n_vocab - 1));
 
             GGML_ASSERT(size == expected_size);
-            GGML_ASSERT(candidates_p.data[0].id == max_token_id);
-            GGML_ASSERT(candidates_p.data[expected_size-1].id == min_token_id);
+            GGML_ASSERT(cur_p.data[0].id == max_token_id);
+            GGML_ASSERT(cur_p.data[expected_size-1].id == min_token_id);
         } else {
             GGML_ABORT("fatal error");
         }
@@ -259,13 +276,13 @@ int main(void) {
     test_typical({0.97f, 0.01f, 0.01f, 0.01f}, {0.97f}, 0.5f);
     test_typical({0.4f, 0.2f, 0.2f, 0.2f}, {0.2f, 0.2f, 0.2f}, 0.5f);
 
-    test_repetition_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0}, {0.25f, 0.25f, 0.25f, 0.25f, 0},   50.0f, 0.0f, 0.0f);
-    test_repetition_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2}, {0.5f, 0.5f, 0, 0, 0},       50.0f, 0.0f, 0.0f);
-    test_repetition_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 0}, {0.5f, 0.5f, 0, 0, 0}, 50.0f, 0.0f, 0.0f);
+    test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0}, {0.25f, 0.25f, 0.25f, 0.25f, 0},   50.0f, 0.0f, 0.0f);
+    test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2}, {0.5f, 0.5f, 0, 0, 0},       50.0f, 0.0f, 0.0f);
+    test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 0}, {0.5f, 0.5f, 0, 0, 0}, 50.0f, 0.0f, 0.0f);
 
-    test_repetition_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0},             {0.249997f, 0.249997f, 0.249997f, 0.249997f, 0.000011f}, 1.0f, 5.0f, 5.0f);
-    test_repetition_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2},       {0.499966f, 0.499966f, 0.000023f, 0.000023f, 0.000023f}, 1.0f, 5.0f, 5.0f);
-    test_repetition_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 0}, {0.499977f, 0.499977f, 0.000023f, 0.000023f, 0.000000f}, 1.0f, 5.0f, 5.0f);
+    test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0},             {0.249997f, 0.249997f, 0.249997f, 0.249997f, 0.000011f}, 1.0f, 5.0f, 5.0f);
+    test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2},       {0.499966f, 0.499966f, 0.000023f, 0.000023f, 0.000023f}, 1.0f, 5.0f, 5.0f);
+    test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 0}, {0.499977f, 0.499977f, 0.000023f, 0.000023f, 0.000000f}, 1.0f, 5.0f, 5.0f);
 
     test_sampler_queue(10000, "k", 10000, 1.0f, 1.0f);
     test_sampler_queue(10000, "k",     1, 1.0f, 1.0f);