]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
speculative : refactor and add a simpler example (#10362)
authorGeorgi Gerganov <redacted>
Mon, 25 Nov 2024 07:58:41 +0000 (09:58 +0200)
committerGitHub <redacted>
Mon, 25 Nov 2024 07:58:41 +0000 (09:58 +0200)
* speculative : refactor and add a simpler example

ggml-ci

* speculative : clean-up and add comments and TODOs [no ci]

* speculative : manage context in common_speculative

ggml-ci

* speculative : simplify

ggml-ci

* speculative : simplify (cont)

ggml-ci

* speculative : add --draft-min CLI arg

* speculative : minor fixup

* make : build fixes

* speculative : do not redraft previous drafts

ggml-ci

* speculative : fix the draft sampling

ggml-ci

* speculative : fix compile warning

* common : refactor args

ggml-ci

* common : change defaults [no ci]

* common : final touches

ggml-ci

28 files changed:
Makefile
common/CMakeLists.txt
common/arg.cpp
common/common.cpp
common/common.h
common/sampling.cpp
common/sampling.h
common/speculative.cpp [new file with mode: 0644]
common/speculative.h [new file with mode: 0644]
examples/CMakeLists.txt
examples/batched/batched.cpp
examples/infill/infill.cpp
examples/llava/llava-cli.cpp
examples/llava/minicpmv-cli.cpp
examples/lookahead/lookahead.cpp
examples/lookup/lookup-stats.cpp
examples/lookup/lookup.cpp
examples/main/main.cpp
examples/parallel/parallel.cpp
examples/retrieval/retrieval.cpp
examples/save-load-state/save-load-state.cpp
examples/server/server.cpp
examples/server/utils.hpp
examples/speculative-simple/CMakeLists.txt [new file with mode: 0644]
examples/speculative-simple/README.md [new file with mode: 0644]
examples/speculative-simple/speculative-simple.cpp [new file with mode: 0644]
examples/speculative/speculative.cpp
tests/test-arg-parser.cpp

index 5c899438515e14c4a1454b33fd68e3bc7ddd75a8..dd6d864ad513a92f48ce022e7d9999336e72e783 100644 (file)
--- a/Makefile
+++ b/Makefile
@@ -966,6 +966,7 @@ OBJ_COMMON = \
        $(DIR_COMMON)/console.o \
        $(DIR_COMMON)/ngram-cache.o \
        $(DIR_COMMON)/sampling.o \
+       $(DIR_COMMON)/speculative.o \
        $(DIR_COMMON)/build-info.o \
        $(DIR_COMMON)/json-schema-to-grammar.o
 
index 5ab1ffa1922aa9a69aa1ae99aace11c309f13c1d..62a8a7db5652f5006133b15fc7ca010d8ca31c03 100644 (file)
@@ -66,6 +66,8 @@ add_library(${TARGET} STATIC
     ngram-cache.h
     sampling.cpp
     sampling.h
+    speculative.cpp
+    speculative.h
     )
 
 if (BUILD_SHARED_LIBS)
index 4115b2f7511d38fdf133d1631c709c15bb154851..32240f21f2469cd6f09132aa07edece1e8d5d8b8 100644 (file)
@@ -233,10 +233,11 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
         }
     }
 
-    postprocess_cpu_params(params.cpuparams, nullptr);
+    postprocess_cpu_params(params.cpuparams,       nullptr);
     postprocess_cpu_params(params.cpuparams_batch, &params.cpuparams);
-    postprocess_cpu_params(params.draft_cpuparams, &params.cpuparams);
-    postprocess_cpu_params(params.draft_cpuparams_batch, &params.cpuparams_batch);
+
+    postprocess_cpu_params(params.speculative.cpuparams,       &params.cpuparams);
+    postprocess_cpu_params(params.speculative.cpuparams_batch, &params.cpuparams_batch);
 
     if (params.prompt_cache_all && (params.interactive || params.interactive_first)) {
         throw std::invalid_argument("error: --prompt-cache-all not supported in interactive mode yet\n");
@@ -251,7 +252,7 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
         for (auto & antiprompt : params.antiprompt) {
             string_process_escapes(antiprompt);
         }
-        for (auto & seq_breaker : params.sparams.dry_sequence_breakers) {
+        for (auto & seq_breaker : params.sampling.dry_sequence_breakers) {
             string_process_escapes(seq_breaker);
         }
     }
@@ -329,7 +330,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
 
     std::string sampler_type_chars;
     std::string sampler_type_names;
-    for (const auto & sampler : params.sparams.samplers) {
+    for (const auto & sampler : params.sampling.samplers) {
         sampler_type_chars += common_sampler_type_to_chr(sampler);
         sampler_type_names += common_sampler_type_to_str(sampler) + ";";
     }
@@ -407,26 +408,6 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             }
         }
     ));
-    add_opt(common_arg(
-        {"-td", "--threads-draft"}, "N",
-        "number of threads to use during generation (default: same as --threads)",
-        [](common_params & params, int value) {
-            params.draft_cpuparams.n_threads = value;
-            if (params.draft_cpuparams.n_threads <= 0) {
-                params.draft_cpuparams.n_threads = std::thread::hardware_concurrency();
-            }
-        }
-    ).set_examples({LLAMA_EXAMPLE_SPECULATIVE}));
-    add_opt(common_arg(
-        {"-tbd", "--threads-batch-draft"}, "N",
-        "number of threads to use during batch and prompt processing (default: same as --threads-draft)",
-        [](common_params & params, int value) {
-            params.draft_cpuparams_batch.n_threads = value;
-            if (params.draft_cpuparams_batch.n_threads <= 0) {
-                params.draft_cpuparams_batch.n_threads = std::thread::hardware_concurrency();
-            }
-        }
-    ).set_examples({LLAMA_EXAMPLE_SPECULATIVE}));
     add_opt(common_arg(
         {"-C", "--cpu-mask"}, "M",
         "CPU affinity mask: arbitrarily long hex. Complements cpu-range (default: \"\")",
@@ -515,108 +496,6 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.cpuparams_batch.poll = value;
         }
     ));
-    add_opt(common_arg(
-        {"-Cd", "--cpu-mask-draft"}, "M",
-        "Draft model CPU affinity mask. Complements cpu-range-draft (default: same as --cpu-mask)",
-        [](common_params & params, const std::string & mask) {
-            params.draft_cpuparams.mask_valid = true;
-            if (!parse_cpu_mask(mask, params.draft_cpuparams.cpumask)) {
-                throw std::invalid_argument("invalid cpumask");
-            }
-        }
-    ).set_examples({LLAMA_EXAMPLE_SPECULATIVE}));
-    add_opt(common_arg(
-        {"-Crd", "--cpu-range-draft"}, "lo-hi",
-        "Ranges of CPUs for affinity. Complements --cpu-mask-draft",
-        [](common_params & params, const std::string & range) {
-            params.draft_cpuparams.mask_valid = true;
-            if (!parse_cpu_range(range, params.draft_cpuparams.cpumask)) {
-                throw std::invalid_argument("invalid range");
-            }
-        }
-    ).set_examples({LLAMA_EXAMPLE_SPECULATIVE}));
-    add_opt(common_arg(
-        {"--cpu-strict-draft"}, "<0|1>",
-        "Use strict CPU placement for draft model (default: same as --cpu-strict)",
-        [](common_params & params, int value) {
-            params.draft_cpuparams.strict_cpu = value;
-        }
-    ).set_examples({LLAMA_EXAMPLE_SPECULATIVE}));
-    add_opt(common_arg(
-        {"--prio-draft"}, "N",
-        string_format("set draft process/thread priority : 0-normal, 1-medium, 2-high, 3-realtime (default: %d)\n", params.draft_cpuparams.priority),
-        [](common_params & params, int prio) {
-            if (prio < 0 || prio > 3) {
-                throw std::invalid_argument("invalid value");
-            }
-            params.draft_cpuparams.priority = (enum ggml_sched_priority) prio;
-        }
-    ).set_examples({LLAMA_EXAMPLE_SPECULATIVE}));
-    add_opt(common_arg(
-        {"--poll-draft"}, "<0|1>",
-        "Use polling to wait for draft model work (default: same as --poll])",
-        [](common_params & params, int value) {
-            params.draft_cpuparams.poll = value;
-        }
-    ).set_examples({LLAMA_EXAMPLE_SPECULATIVE}));
-    add_opt(common_arg(
-        {"-Cbd", "--cpu-mask-batch-draft"}, "M",
-        "Draft model CPU affinity mask. Complements cpu-range-draft (default: same as --cpu-mask)",
-        [](common_params & params, const std::string & mask) {
-            params.draft_cpuparams_batch.mask_valid = true;
-            if (!parse_cpu_mask(mask, params.draft_cpuparams_batch.cpumask)) {
-                throw std::invalid_argument("invalid cpumask");
-            }
-        }
-    ).set_examples({LLAMA_EXAMPLE_SPECULATIVE}));
-    add_opt(common_arg(
-        {"-Crbd", "--cpu-range-batch-draft"}, "lo-hi",
-        "Ranges of CPUs for affinity. Complements --cpu-mask-draft-batch)",
-        [](common_params & params, const std::string & range) {
-            params.draft_cpuparams_batch.mask_valid = true;
-            if (!parse_cpu_range(range, params.draft_cpuparams_batch.cpumask)) {
-                throw std::invalid_argument("invalid cpumask");
-            }
-        }
-    ).set_examples({LLAMA_EXAMPLE_SPECULATIVE}));
-    add_opt(common_arg(
-        {"--cpu-strict-batch-draft"}, "<0|1>",
-        "Use strict CPU placement for draft model (default: --cpu-strict-draft)",
-        [](common_params & params, int value) {
-            params.draft_cpuparams_batch.strict_cpu = value;
-        }
-    ).set_examples({LLAMA_EXAMPLE_SPECULATIVE}));
-    add_opt(common_arg(
-        {"--prio-batch-draft"}, "N",
-        string_format("set draft process/thread priority : 0-normal, 1-medium, 2-high, 3-realtime (default: %d)\n", params.draft_cpuparams_batch.priority),
-        [](common_params & params, int prio) {
-            if (prio < 0 || prio > 3) {
-                throw std::invalid_argument("invalid value");
-            }
-            params.draft_cpuparams_batch.priority = (enum ggml_sched_priority) prio;
-        }
-    ).set_examples({LLAMA_EXAMPLE_SPECULATIVE}));
-    add_opt(common_arg(
-        {"--poll-batch-draft"}, "<0|1>",
-        "Use polling to wait for draft model work (default: --poll-draft)",
-        [](common_params & params, int value) {
-            params.draft_cpuparams_batch.poll = value;
-        }
-    ).set_examples({LLAMA_EXAMPLE_SPECULATIVE}));
-    add_opt(common_arg(
-        {"--draft"}, "N",
-        string_format("number of tokens to draft for speculative decoding (default: %d)", params.n_draft),
-        [](common_params & params, int value) {
-            params.n_draft = value;
-        }
-    ).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_LOOKUP}));
-    add_opt(common_arg(
-        {"-ps", "--p-split"}, "N",
-        string_format("speculative decoding split probability (default: %.1f)", (double)params.p_split),
-        [](common_params & params, const std::string & value) {
-            params.p_split = std::stof(value);
-        }
-    ).set_examples({LLAMA_EXAMPLE_SPECULATIVE}));
     add_opt(common_arg(
         {"-lcs", "--lookup-cache-static"}, "FNAME",
         "path to static lookup cache to use for lookup decoding (not updated by generation)",
@@ -701,7 +580,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
         string_format("disable internal libllama performance timings (default: %s)", params.no_perf ? "true" : "false"),
         [](common_params & params) {
             params.no_perf = true;
-            params.sparams.no_perf = true;
+            params.sampling.no_perf = true;
         }
     ).set_env("LLAMA_ARG_NO_PERF"));
     add_opt(common_arg(
@@ -883,155 +762,155 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
         string_format("samplers that will be used for generation in the order, separated by \';\'\n(default: %s)", sampler_type_names.c_str()),
         [](common_params & params, const std::string & value) {
             const auto sampler_names = string_split<std::string>(value, ';');
-            params.sparams.samplers = common_sampler_types_from_names(sampler_names, true);
+            params.sampling.samplers = common_sampler_types_from_names(sampler_names, true);
         }
     ).set_sparam());
     add_opt(common_arg(
         {"-s", "--seed"}, "SEED",
-        string_format("RNG seed (default: %d, use random seed for %d)", params.sparams.seed, LLAMA_DEFAULT_SEED),
+        string_format("RNG seed (default: %d, use random seed for %d)", params.sampling.seed, LLAMA_DEFAULT_SEED),
         [](common_params & params, const std::string & value) {
-            params.sparams.seed = std::stoul(value);
+            params.sampling.seed = std::stoul(value);
         }
     ).set_sparam());
     add_opt(common_arg(
         {"--sampling-seq"}, "SEQUENCE",
         string_format("simplified sequence for samplers that will be used (default: %s)", sampler_type_chars.c_str()),
         [](common_params & params, const std::string & value) {
-            params.sparams.samplers = common_sampler_types_from_chars(value);
+            params.sampling.samplers = common_sampler_types_from_chars(value);
         }
     ).set_sparam());
     add_opt(common_arg(
         {"--ignore-eos"},
         "ignore end of stream token and continue generating (implies --logit-bias EOS-inf)",
         [](common_params & params) {
-            params.sparams.ignore_eos = true;
+            params.sampling.ignore_eos = true;
         }
     ).set_sparam());
     add_opt(common_arg(
         {"--penalize-nl"},
-        string_format("penalize newline tokens (default: %s)", params.sparams.penalize_nl ? "true" : "false"),
+        string_format("penalize newline tokens (default: %s)", params.sampling.penalize_nl ? "true" : "false"),
         [](common_params & params) {
-            params.sparams.penalize_nl = true;
+            params.sampling.penalize_nl = true;
         }
     ).set_sparam());
     add_opt(common_arg(
         {"--temp"}, "N",
-        string_format("temperature (default: %.1f)", (double)params.sparams.temp),
+        string_format("temperature (default: %.1f)", (double)params.sampling.temp),
         [](common_params & params, const std::string & value) {
-            params.sparams.temp = std::stof(value);
-            params.sparams.temp = std::max(params.sparams.temp, 0.0f);
+            params.sampling.temp = std::stof(value);
+            params.sampling.temp = std::max(params.sampling.temp, 0.0f);
         }
     ).set_sparam());
     add_opt(common_arg(
         {"--top-k"}, "N",
-        string_format("top-k sampling (default: %d, 0 = disabled)", params.sparams.top_k),
+        string_format("top-k sampling (default: %d, 0 = disabled)", params.sampling.top_k),
         [](common_params & params, int value) {
-            params.sparams.top_k = value;
+            params.sampling.top_k = value;
         }
     ).set_sparam());
     add_opt(common_arg(
         {"--top-p"}, "N",
-        string_format("top-p sampling (default: %.1f, 1.0 = disabled)", (double)params.sparams.top_p),
+        string_format("top-p sampling (default: %.1f, 1.0 = disabled)", (double)params.sampling.top_p),
         [](common_params & params, const std::string & value) {
-            params.sparams.top_p = std::stof(value);
+            params.sampling.top_p = std::stof(value);
         }
     ).set_sparam());
     add_opt(common_arg(
         {"--min-p"}, "N",
-        string_format("min-p sampling (default: %.1f, 0.0 = disabled)", (double)params.sparams.min_p),
+        string_format("min-p sampling (default: %.1f, 0.0 = disabled)", (double)params.sampling.min_p),
         [](common_params & params, const std::string & value) {
-            params.sparams.min_p = std::stof(value);
+            params.sampling.min_p = std::stof(value);
         }
     ).set_sparam());
     add_opt(common_arg(
         {"--xtc-probability"}, "N",
-        string_format("xtc probability (default: %.1f, 0.0 = disabled)", (double)params.sparams.xtc_probability),
+        string_format("xtc probability (default: %.1f, 0.0 = disabled)", (double)params.sampling.xtc_probability),
         [](common_params & params, const std::string & value) {
-            params.sparams.xtc_probability = std::stof(value);
+            params.sampling.xtc_probability = std::stof(value);
         }
     ).set_sparam());
     add_opt(common_arg(
         {"--xtc-threshold"}, "N",
-        string_format("xtc threshold (default: %.1f, 1.0 = disabled)", (double)params.sparams.xtc_threshold),
+        string_format("xtc threshold (default: %.1f, 1.0 = disabled)", (double)params.sampling.xtc_threshold),
         [](common_params & params, const std::string & value) {
-            params.sparams.xtc_threshold = std::stof(value);
+            params.sampling.xtc_threshold = std::stof(value);
         }
     ).set_sparam());
     add_opt(common_arg(
         {"--typical"}, "N",
-        string_format("locally typical sampling, parameter p (default: %.1f, 1.0 = disabled)", (double)params.sparams.typ_p),
+        string_format("locally typical sampling, parameter p (default: %.1f, 1.0 = disabled)", (double)params.sampling.typ_p),
         [](common_params & params, const std::string & value) {
-            params.sparams.typ_p = std::stof(value);
+            params.sampling.typ_p = std::stof(value);
         }
     ).set_sparam());
     add_opt(common_arg(
         {"--repeat-last-n"}, "N",
-        string_format("last n tokens to consider for penalize (default: %d, 0 = disabled, -1 = ctx_size)", params.sparams.penalty_last_n),
+        string_format("last n tokens to consider for penalize (default: %d, 0 = disabled, -1 = ctx_size)", params.sampling.penalty_last_n),
         [](common_params & params, int value) {
-            params.sparams.penalty_last_n = value;
-            params.sparams.n_prev = std::max(params.sparams.n_prev, params.sparams.penalty_last_n);
+            params.sampling.penalty_last_n = value;
+            params.sampling.n_prev = std::max(params.sampling.n_prev, params.sampling.penalty_last_n);
         }
     ).set_sparam());
     add_opt(common_arg(
         {"--repeat-penalty"}, "N",
-        string_format("penalize repeat sequence of tokens (default: %.1f, 1.0 = disabled)", (double)params.sparams.penalty_repeat),
+        string_format("penalize repeat sequence of tokens (default: %.1f, 1.0 = disabled)", (double)params.sampling.penalty_repeat),
         [](common_params & params, const std::string & value) {
-            params.sparams.penalty_repeat = std::stof(value);
+            params.sampling.penalty_repeat = std::stof(value);
         }
     ).set_sparam());
     add_opt(common_arg(
         {"--presence-penalty"}, "N",
-        string_format("repeat alpha presence penalty (default: %.1f, 0.0 = disabled)", (double)params.sparams.penalty_present),
+        string_format("repeat alpha presence penalty (default: %.1f, 0.0 = disabled)", (double)params.sampling.penalty_present),
         [](common_params & params, const std::string & value) {
-            params.sparams.penalty_present = std::stof(value);
+            params.sampling.penalty_present = std::stof(value);
         }
     ).set_sparam());
     add_opt(common_arg(
         {"--frequency-penalty"}, "N",
-        string_format("repeat alpha frequency penalty (default: %.1f, 0.0 = disabled)", (double)params.sparams.penalty_freq),
+        string_format("repeat alpha frequency penalty (default: %.1f, 0.0 = disabled)", (double)params.sampling.penalty_freq),
         [](common_params & params, const std::string & value) {
-            params.sparams.penalty_freq = std::stof(value);
+            params.sampling.penalty_freq = std::stof(value);
         }
     ).set_sparam());
     add_opt(common_arg(
         {"--dry-multiplier"}, "N",
-        string_format("set DRY sampling multiplier (default: %.1f, 0.0 = disabled)", (double)params.sparams.dry_multiplier),
+        string_format("set DRY sampling multiplier (default: %.1f, 0.0 = disabled)", (double)params.sampling.dry_multiplier),
         [](common_params & params, const std::string & value) {
-            params.sparams.dry_multiplier = std::stof(value);
+            params.sampling.dry_multiplier = std::stof(value);
         }
     ).set_sparam());
     add_opt(common_arg(
         {"--dry-base"}, "N",
-        string_format("set DRY sampling base value (default: %.2f)", (double)params.sparams.dry_base),
+        string_format("set DRY sampling base value (default: %.2f)", (double)params.sampling.dry_base),
         [](common_params & params, const std::string & value) {
             float potential_base = std::stof(value);
             if (potential_base >= 1.0f)
             {
-                params.sparams.dry_base = potential_base;
+                params.sampling.dry_base = potential_base;
             }
         }
     ).set_sparam());
     add_opt(common_arg(
         {"--dry-allowed-length"}, "N",
-        string_format("set allowed length for DRY sampling (default: %d)", params.sparams.dry_allowed_length),
+        string_format("set allowed length for DRY sampling (default: %d)", params.sampling.dry_allowed_length),
         [](common_params & params, int value) {
-            params.sparams.dry_allowed_length = value;
+            params.sampling.dry_allowed_length = value;
         }
     ).set_sparam());
     add_opt(common_arg(
         {"--dry-penalty-last-n"}, "N",
-        string_format("set DRY penalty for the last n tokens (default: %d, 0 = disable, -1 = context size)", params.sparams.dry_penalty_last_n),
+        string_format("set DRY penalty for the last n tokens (default: %d, 0 = disable, -1 = context size)", params.sampling.dry_penalty_last_n),
         [](common_params & params, int value) {
-            params.sparams.dry_penalty_last_n = value;
+            params.sampling.dry_penalty_last_n = value;
         }
     ).set_sparam());
     add_opt(common_arg(
         {"--dry-sequence-breaker"}, "STRING",
         string_format("add sequence breaker for DRY sampling, clearing out default breakers (%s) in the process; use \"none\" to not use any sequence breakers\n",
-            params.sparams.dry_sequence_breakers.empty() ? "none" :
-            std::accumulate(std::next(params.sparams.dry_sequence_breakers.begin()),
-                params.sparams.dry_sequence_breakers.end(),
-                std::string("'") + (params.sparams.dry_sequence_breakers[0] == "\n" ? "\\n" : params.sparams.dry_sequence_breakers[0]) + "'",
+            params.sampling.dry_sequence_breakers.empty() ? "none" :
+            std::accumulate(std::next(params.sampling.dry_sequence_breakers.begin()),
+                params.sampling.dry_sequence_breakers.end(),
+                std::string("'") + (params.sampling.dry_sequence_breakers[0] == "\n" ? "\\n" : params.sampling.dry_sequence_breakers[0]) + "'",
                 [](const std::string& a, const std::string& b) {
                     std::string formatted_b = (b == "\n") ? "\\n" : b;
                     return a + ", '" + formatted_b + "'";
@@ -1040,51 +919,51 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             static bool defaults_cleared = false;
 
             if (!defaults_cleared) {
-                params.sparams.dry_sequence_breakers.clear();
+                params.sampling.dry_sequence_breakers.clear();
                 defaults_cleared = true;
             }
 
             if (value == "none") {
-                params.sparams.dry_sequence_breakers.clear();
+                params.sampling.dry_sequence_breakers.clear();
             } else {
-                params.sparams.dry_sequence_breakers.emplace_back(value);
+                params.sampling.dry_sequence_breakers.emplace_back(value);
             }
         }
     ).set_sparam());
     add_opt(common_arg(
         {"--dynatemp-range"}, "N",
-        string_format("dynamic temperature range (default: %.1f, 0.0 = disabled)", (double)params.sparams.dynatemp_range),
+        string_format("dynamic temperature range (default: %.1f, 0.0 = disabled)", (double)params.sampling.dynatemp_range),
         [](common_params & params, const std::string & value) {
-            params.sparams.dynatemp_range = std::stof(value);
+            params.sampling.dynatemp_range = std::stof(value);
         }
     ).set_sparam());
     add_opt(common_arg(
         {"--dynatemp-exp"}, "N",
-        string_format("dynamic temperature exponent (default: %.1f)", (double)params.sparams.dynatemp_exponent),
+        string_format("dynamic temperature exponent (default: %.1f)", (double)params.sampling.dynatemp_exponent),
         [](common_params & params, const std::string & value) {
-            params.sparams.dynatemp_exponent = std::stof(value);
+            params.sampling.dynatemp_exponent = std::stof(value);
         }
     ).set_sparam());
     add_opt(common_arg(
         {"--mirostat"}, "N",
         string_format("use Mirostat sampling.\nTop K, Nucleus and Locally Typical samplers are ignored if used.\n"
-        "(default: %d, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)", params.sparams.mirostat),
+        "(default: %d, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)", params.sampling.mirostat),
         [](common_params & params, int value) {
-            params.sparams.mirostat = value;
+            params.sampling.mirostat = value;
         }
     ).set_sparam());
     add_opt(common_arg(
         {"--mirostat-lr"}, "N",
-        string_format("Mirostat learning rate, parameter eta (default: %.1f)", (double)params.sparams.mirostat_eta),
+        string_format("Mirostat learning rate, parameter eta (default: %.1f)", (double)params.sampling.mirostat_eta),
         [](common_params & params, const std::string & value) {
-            params.sparams.mirostat_eta = std::stof(value);
+            params.sampling.mirostat_eta = std::stof(value);
         }
     ).set_sparam());
     add_opt(common_arg(
         {"--mirostat-ent"}, "N",
-        string_format("Mirostat target entropy, parameter tau (default: %.1f)", (double)params.sparams.mirostat_tau),
+        string_format("Mirostat target entropy, parameter tau (default: %.1f)", (double)params.sampling.mirostat_tau),
         [](common_params & params, const std::string & value) {
-            params.sparams.mirostat_tau = std::stof(value);
+            params.sampling.mirostat_tau = std::stof(value);
         }
     ).set_sparam());
     add_opt(common_arg(
@@ -1100,7 +979,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             try {
                 if (ss >> key && ss >> sign && std::getline(ss, value_str) && (sign == '+' || sign == '-')) {
                     const float bias = std::stof(value_str) * ((sign == '-') ? -1.0f : 1.0f);
-                    params.sparams.logit_bias.push_back({key, bias});
+                    params.sampling.logit_bias.push_back({key, bias});
                 } else {
                     throw std::invalid_argument("invalid input format");
                 }
@@ -1111,9 +990,9 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
     ).set_sparam());
     add_opt(common_arg(
         {"--grammar"}, "GRAMMAR",
-        string_format("BNF-like grammar to constrain generations (see samples in grammars/ dir) (default: '%s')", params.sparams.grammar.c_str()),
+        string_format("BNF-like grammar to constrain generations (see samples in grammars/ dir) (default: '%s')", params.sampling.grammar.c_str()),
         [](common_params & params, const std::string & value) {
-            params.sparams.grammar = value;
+            params.sampling.grammar = value;
         }
     ).set_sparam());
     add_opt(common_arg(
@@ -1127,7 +1006,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             std::copy(
                 std::istreambuf_iterator<char>(file),
                 std::istreambuf_iterator<char>(),
-                std::back_inserter(params.sparams.grammar)
+                std::back_inserter(params.sampling.grammar)
             );
         }
     ).set_sparam());
@@ -1135,7 +1014,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
         {"-j", "--json-schema"}, "SCHEMA",
         "JSON schema to constrain generations (https://json-schema.org/), e.g. `{}` for any JSON object\nFor schemas w/ external $refs, use --grammar + example/json_schema_to_grammar.py instead",
         [](common_params & params, const std::string & value) {
-            params.sparams.grammar = json_schema_to_grammar(json::parse(value));
+            params.sampling.grammar = json_schema_to_grammar(json::parse(value));
         }
     ).set_sparam());
     add_opt(common_arg(
@@ -1444,17 +1323,6 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             }
         }
     ).set_env("LLAMA_ARG_N_GPU_LAYERS"));
-    add_opt(common_arg(
-        {"-ngld", "--gpu-layers-draft", "--n-gpu-layers-draft"}, "N",
-        "number of layers to store in VRAM for the draft model",
-        [](common_params & params, int value) {
-            params.n_gpu_layers_draft = value;
-            if (!llama_supports_gpu_offload()) {
-                fprintf(stderr, "warning: not compiled with GPU offload support, --gpu-layers-draft option will be ignored\n");
-                fprintf(stderr, "warning: see main README.md for information on enabling GPU BLAS support\n");
-            }
-        }
-    ).set_examples({LLAMA_EXAMPLE_SPECULATIVE}));
     add_opt(common_arg(
         {"-sm", "--split-mode"}, "{none,layer,row}",
         "how to split the model across multiple GPUs, one of:\n"
@@ -1593,13 +1461,6 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.model = value;
         }
     ).set_examples({LLAMA_EXAMPLE_COMMON, LLAMA_EXAMPLE_EXPORT_LORA}).set_env("LLAMA_ARG_MODEL"));
-    add_opt(common_arg(
-        {"-md", "--model-draft"}, "FNAME",
-        "draft model for speculative decoding (default: unused)",
-        [](common_params & params, const std::string & value) {
-            params.model_draft = value;
-        }
-    ).set_examples({LLAMA_EXAMPLE_SPECULATIVE}));
     add_opt(common_arg(
         {"-mu", "--model-url"}, "MODEL_URL",
         "model download url (default: unused)",
@@ -2037,5 +1898,168 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
         }
     ).set_env("LLAMA_LOG_TIMESTAMPS"));
 
+    // speculative parameters
+    add_opt(common_arg(
+        {"-td", "--threads-draft"}, "N",
+        "number of threads to use during generation (default: same as --threads)",
+        [](common_params & params, int value) {
+            params.speculative.cpuparams.n_threads = value;
+            if (params.speculative.cpuparams.n_threads <= 0) {
+                params.speculative.cpuparams.n_threads = std::thread::hardware_concurrency();
+            }
+        }
+    ).set_examples({LLAMA_EXAMPLE_SPECULATIVE}));
+    add_opt(common_arg(
+        {"-tbd", "--threads-batch-draft"}, "N",
+        "number of threads to use during batch and prompt processing (default: same as --threads-draft)",
+        [](common_params & params, int value) {
+            params.speculative.cpuparams_batch.n_threads = value;
+            if (params.speculative.cpuparams_batch.n_threads <= 0) {
+                params.speculative.cpuparams_batch.n_threads = std::thread::hardware_concurrency();
+            }
+        }
+    ).set_examples({LLAMA_EXAMPLE_SPECULATIVE}));
+    add_opt(common_arg(
+        {"-Cd", "--cpu-mask-draft"}, "M",
+        "Draft model CPU affinity mask. Complements cpu-range-draft (default: same as --cpu-mask)",
+        [](common_params & params, const std::string & mask) {
+            params.speculative.cpuparams.mask_valid = true;
+            if (!parse_cpu_mask(mask, params.speculative.cpuparams.cpumask)) {
+                throw std::invalid_argument("invalid cpumask");
+            }
+        }
+    ).set_examples({LLAMA_EXAMPLE_SPECULATIVE}));
+    add_opt(common_arg(
+        {"-Crd", "--cpu-range-draft"}, "lo-hi",
+        "Ranges of CPUs for affinity. Complements --cpu-mask-draft",
+        [](common_params & params, const std::string & range) {
+            params.speculative.cpuparams.mask_valid = true;
+            if (!parse_cpu_range(range, params.speculative.cpuparams.cpumask)) {
+                throw std::invalid_argument("invalid range");
+            }
+        }
+    ).set_examples({LLAMA_EXAMPLE_SPECULATIVE}));
+    add_opt(common_arg(
+        {"--cpu-strict-draft"}, "<0|1>",
+        "Use strict CPU placement for draft model (default: same as --cpu-strict)",
+        [](common_params & params, int value) {
+            params.speculative.cpuparams.strict_cpu = value;
+        }
+    ).set_examples({LLAMA_EXAMPLE_SPECULATIVE}));
+    add_opt(common_arg(
+        {"--prio-draft"}, "N",
+        string_format("set draft process/thread priority : 0-normal, 1-medium, 2-high, 3-realtime (default: %d)\n", params.speculative.cpuparams.priority),
+        [](common_params & params, int prio) {
+            if (prio < 0 || prio > 3) {
+                throw std::invalid_argument("invalid value");
+            }
+            params.speculative.cpuparams.priority = (enum ggml_sched_priority) prio;
+        }
+    ).set_examples({LLAMA_EXAMPLE_SPECULATIVE}));
+    add_opt(common_arg(
+        {"--poll-draft"}, "<0|1>",
+        "Use polling to wait for draft model work (default: same as --poll])",
+        [](common_params & params, int value) {
+            params.speculative.cpuparams.poll = value;
+        }
+    ).set_examples({LLAMA_EXAMPLE_SPECULATIVE}));
+    add_opt(common_arg(
+        {"-Cbd", "--cpu-mask-batch-draft"}, "M",
+        "Draft model CPU affinity mask. Complements cpu-range-draft (default: same as --cpu-mask)",
+        [](common_params & params, const std::string & mask) {
+            params.speculative.cpuparams_batch.mask_valid = true;
+            if (!parse_cpu_mask(mask, params.speculative.cpuparams_batch.cpumask)) {
+                throw std::invalid_argument("invalid cpumask");
+            }
+        }
+    ).set_examples({LLAMA_EXAMPLE_SPECULATIVE}));
+    add_opt(common_arg(
+        {"-Crbd", "--cpu-range-batch-draft"}, "lo-hi",
+        "Ranges of CPUs for affinity. Complements --cpu-mask-draft-batch)",
+        [](common_params & params, const std::string & range) {
+            params.speculative.cpuparams_batch.mask_valid = true;
+            if (!parse_cpu_range(range, params.speculative.cpuparams_batch.cpumask)) {
+                throw std::invalid_argument("invalid cpumask");
+            }
+        }
+    ).set_examples({LLAMA_EXAMPLE_SPECULATIVE}));
+    add_opt(common_arg(
+        {"--cpu-strict-batch-draft"}, "<0|1>",
+        "Use strict CPU placement for draft model (default: --cpu-strict-draft)",
+        [](common_params & params, int value) {
+            params.speculative.cpuparams_batch.strict_cpu = value;
+        }
+    ).set_examples({LLAMA_EXAMPLE_SPECULATIVE}));
+    add_opt(common_arg(
+        {"--prio-batch-draft"}, "N",
+        string_format("set draft process/thread priority : 0-normal, 1-medium, 2-high, 3-realtime (default: %d)\n", params.speculative.cpuparams_batch.priority),
+        [](common_params & params, int prio) {
+            if (prio < 0 || prio > 3) {
+                throw std::invalid_argument("invalid value");
+            }
+            params.speculative.cpuparams_batch.priority = (enum ggml_sched_priority) prio;
+        }
+    ).set_examples({LLAMA_EXAMPLE_SPECULATIVE}));
+    add_opt(common_arg(
+        {"--poll-batch-draft"}, "<0|1>",
+        "Use polling to wait for draft model work (default: --poll-draft)",
+        [](common_params & params, int value) {
+            params.speculative.cpuparams_batch.poll = value;
+        }
+    ).set_examples({LLAMA_EXAMPLE_SPECULATIVE}));
+    add_opt(common_arg(
+        {"--draft-max", "--draft", "--draft-n"}, "N",
+        string_format("number of tokens to draft for speculative decoding (default: %d)", params.speculative.n_max),
+        [](common_params & params, int value) {
+            params.speculative.n_max = value;
+        }
+    ).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_LOOKUP, LLAMA_EXAMPLE_SERVER}));
+    add_opt(common_arg(
+        {"--draft-min", "--draft-n-min"}, "N",
+        string_format("minimum number of draft tokens to use for speculative decoding (default: %d)", params.speculative.n_min),
+        [](common_params & params, int value) {
+            params.speculative.n_min = value;
+        }
+    ).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_LOOKUP, LLAMA_EXAMPLE_SERVER}));
+    add_opt(common_arg(
+        {"--draft-p-split"}, "P",
+        string_format("speculative decoding split probability (default: %.1f)", (double)params.speculative.p_split),
+        [](common_params & params, const std::string & value) {
+            params.speculative.p_split = std::stof(value);
+        }
+    ).set_examples({LLAMA_EXAMPLE_SPECULATIVE}));
+    add_opt(common_arg(
+        {"--draft-p-min"}, "P",
+        string_format("minimum speculative decoding probability (greedy) (default: %.1f)", (double)params.speculative.p_min),
+        [](common_params & params, const std::string & value) {
+            params.speculative.p_min = std::stof(value);
+        }
+    ).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER}));
+    add_opt(common_arg(
+        {"-cd", "--ctx-size-draft"}, "N",
+        string_format("size of the prompt context for the draft model (default: %d, 0 = loaded from model)", params.speculative.n_ctx),
+        [](common_params & params, int value) {
+            params.speculative.n_ctx = value;
+        }
+    ).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER}));
+    add_opt(common_arg(
+        {"-ngld", "--gpu-layers-draft", "--n-gpu-layers-draft"}, "N",
+        "number of layers to store in VRAM for the draft model",
+        [](common_params & params, int value) {
+            params.speculative.n_gpu_layers = value;
+            if (!llama_supports_gpu_offload()) {
+                fprintf(stderr, "warning: not compiled with GPU offload support, --gpu-layers-draft option will be ignored\n");
+                fprintf(stderr, "warning: see main README.md for information on enabling GPU BLAS support\n");
+            }
+        }
+    ).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER}));
+    add_opt(common_arg(
+        {"-md", "--model-draft"}, "FNAME",
+        "draft model for speculative decoding (default: unused)",
+        [](common_params & params, const std::string & value) {
+            params.speculative.model = value;
+        }
+    ).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER}));
+
     return ctx_arg;
 }
index d314523db4c62510842c6d84beaf10a67a7a9355..c398329d05bf531538afecc0cacded5c2401d040 100644 (file)
@@ -536,12 +536,12 @@ std::string string_from(const struct llama_context * ctx, const struct llama_bat
                     [](const unsigned char c) { return !std::isprint(c); }),
                 detokenized.end());
 
-        buf << "\n" << std::to_string(i)
-            << ":token '" << detokenized << "'"
-            << ":pos " << std::to_string(batch.pos[i])
-            << ":n_seq_id  " << std::to_string(batch.n_seq_id[i])
-            << ":seq_id " << std::to_string(batch.seq_id[i][0])
-            << ":logits " << std::to_string(batch.logits[i]);
+        buf << "\n"          << std::to_string(i)
+            << ", token '"   << detokenized << "'"
+            << ", pos "      << std::to_string(batch.pos[i])
+            << ", n_seq_id " << std::to_string(batch.n_seq_id[i])
+            << ", seq_id "   << std::to_string(batch.seq_id[i][0])
+            << ", logits "   << std::to_string(batch.logits[i]);
     }
 
     buf << " ]";
@@ -925,9 +925,9 @@ struct common_init_result common_init_from_params(common_params & params) {
         common_lora_adapters_apply(lctx, iparams.lora_adapters);
     }
 
-    if (params.sparams.ignore_eos && llama_token_eos(model) == LLAMA_TOKEN_NULL) {
+    if (params.sampling.ignore_eos && llama_token_eos(model) == LLAMA_TOKEN_NULL) {
         LOG_WRN("%s: warning: model does not have an EOS token, ignoring --ignore-eos\n", __func__);
-        params.sparams.ignore_eos = false;
+        params.sampling.ignore_eos = false;
     }
 
     if (params.warmup) {
@@ -1490,6 +1490,66 @@ void common_batch_add(
     batch.n_tokens++;
 }
 
+//
+// Token utils
+//
+
+size_t common_lcp(const llama_tokens & a, const llama_tokens & b) {
+    size_t i;
+    for (i = 0; i < a.size() && i < b.size() && a[i] == b[i]; i++) {}
+
+    return i;
+}
+
+size_t common_lcs(const llama_tokens & a, const llama_tokens & b) {
+    // check for empty sequences
+    if (a.empty() || b.empty()) {
+        return 0;
+    }
+
+    // get the lengths of the input sequences
+    size_t a_len = a.size();
+    size_t b_len = b.size();
+
+    // initialize the maximum length of the longest common subsequence (LCS)
+    size_t max_length = 0;
+
+    // use two rows instead of a 2D matrix to optimize space
+    std::vector<size_t> prev_row(b_len + 1, 0);
+    std::vector<size_t> curr_row(b_len + 1, 0);
+
+    // iterate through the elements of a
+    for (size_t i = 1; i <= a_len; i++) {
+        // iterate through the elements of b
+        for (size_t j = 1; j <= b_len; j++) {
+            // if elements at the current positions match
+            if (a[i - 1] == b[j - 1]) {
+                // if it's the first element of either sequences, set LCS length to 1
+                if (i == 1 || j == 1) {
+                    curr_row[j] = 1;
+                } else {
+                    // increment LCS length by 1 compared to the previous element
+                    curr_row[j] = prev_row[j - 1] + 1;
+                }
+
+                // update max_length if necessary
+                if (curr_row[j] > max_length) {
+                    max_length = curr_row[j];
+                }
+            } else {
+                // reset LCS length if elements don't match
+                curr_row[j] = 0;
+            }
+        }
+
+        // update the previous row for the next iteration
+        prev_row = curr_row;
+    }
+
+    // return the maximum length of the LCS
+    return max_length;
+}
+
 //
 // Vocab utils
 //
index 7977cc7a99a78758f838b23a8d6b24dbc43a5ce0..5c579b5abfe03f638d321f6645e2ffe77dbfbb50 100644 (file)
@@ -33,6 +33,8 @@ struct common_lora_adapter_container : common_lora_adapter_info {
     struct llama_lora_adapter * adapter;
 };
 
+using llama_tokens = std::vector<llama_token>;
+
 // build info
 extern int LLAMA_BUILD_NUMBER;
 extern char const * LLAMA_COMMIT;
@@ -101,8 +103,8 @@ enum dimre_method {
     DIMRE_METHOD_MEAN,
 };
 
-// sampler parameters
-struct common_sampler_params {
+// sampling parameters
+struct common_params_sampling {
     uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampler
 
     int32_t n_prev             = 64;    // number of previous tokens to remember
@@ -153,19 +155,30 @@ struct common_sampler_params {
     std::string print() const;
 };
 
+struct common_params_speculative {
+    int32_t n_ctx        =     0; // draft context size
+    int32_t n_max        =    16; // maximum number of tokens to draft during speculative decoding
+    int32_t n_min        =     5; // minimum number of draft tokens to use for speculative decoding
+    int32_t n_gpu_layers =    -1; // number of layers to store in VRAM for the draft model (-1 - use default)
+    float   p_split      =  0.1f; // speculative decoding split probability
+    float   p_min        =  0.9f; // minimum speculative decoding probability (greedy)
+
+    struct cpu_params cpuparams;
+    struct cpu_params cpuparams_batch;
+
+    std::string model = ""; // draft model for speculative decoding                          // NOLINT
+};
+
 struct common_params {
     int32_t n_predict             =    -1; // new tokens to predict
     int32_t n_ctx                 =  4096; // context size
     int32_t n_batch               =  2048; // logical batch size for prompt processing (must be >=32 to use BLAS)
     int32_t n_ubatch              =   512; // physical batch size for prompt processing (must be >=32 to use BLAS)
     int32_t n_keep                =     0; // number of tokens to keep from initial prompt
-    int32_t n_draft               =     5; // number of tokens to draft during speculative decoding
     int32_t n_chunks              =    -1; // max number of chunks to process (-1 = unlimited)
     int32_t n_parallel            =     1; // number of parallel sequences to decode
     int32_t n_sequences           =     1; // number of sequences to decode
-    float   p_split               =  0.1f; // speculative decoding split probability
     int32_t n_gpu_layers          =    -1; // number of layers to store in VRAM (-1 - use default)
-    int32_t n_gpu_layers_draft    =    -1; // number of layers to store in VRAM for the draft model (-1 - use default)
     int32_t main_gpu              =     0; // the GPU that is used for scratch and small tensors
     float   tensor_split[128]     =   {0}; // how split tensors should be distributed across GPUs
     int32_t grp_attn_n            =     1; // group-attention factor
@@ -182,8 +195,6 @@ struct common_params {
 
     struct cpu_params cpuparams;
     struct cpu_params cpuparams_batch;
-    struct cpu_params draft_cpuparams;
-    struct cpu_params draft_cpuparams_batch;
 
     ggml_backend_sched_eval_callback cb_eval = nullptr;
     void * cb_eval_user_data                 = nullptr;
@@ -195,10 +206,10 @@ struct common_params {
     enum llama_pooling_type      pooling_type      = LLAMA_POOLING_TYPE_UNSPECIFIED; // pooling type for embeddings
     enum llama_attention_type    attention_type    = LLAMA_ATTENTION_TYPE_UNSPECIFIED; // attention type for embeddings
 
-    struct common_sampler_params sparams;
+    struct common_params_sampling sampling;
+    struct common_params_speculative speculative;
 
     std::string model                = ""; // model path                                                    // NOLINT
-    std::string model_draft          = ""; // draft model for speculative decoding                          // NOLINT
     std::string model_alias          = "unknown"; // model alias                                            // NOLINT
     std::string model_url            = ""; // model url to download                                         // NOLINT
     std::string hf_token             = ""; // HF token                                                      // NOLINT
@@ -461,7 +472,9 @@ struct llama_model * common_load_model_from_hf(const char * repo, const char * f
 // clear LoRA adapters from context, then apply new list of adapters
 void common_lora_adapters_apply(struct llama_context * ctx, std::vector<common_lora_adapter_container> & lora_adapters);
 
+//
 // Batch utils
+//
 
 void common_batch_clear(struct llama_batch & batch);
 
@@ -472,6 +485,16 @@ void common_batch_add(
     const std::vector<llama_seq_id> & seq_ids,
                                bool   logits);
 
+//
+// Token utils
+//
+
+// longest common prefix
+size_t common_lcp(const llama_tokens & a, const llama_tokens & b);
+
+// longet common subsequence
+size_t common_lcs(const llama_tokens & a, const llama_tokens & b);
+
 //
 // Vocab utils
 //
index 7922fde47d3693bfe9fc02c642cb6c37f784b084..0c4699a89c8b24a10fe8638bd6aba9cde2191187 100644 (file)
@@ -99,7 +99,7 @@ struct ring_buffer {
 };
 
 struct common_sampler {
-    common_sampler_params params;
+    common_params_sampling params;
 
     struct llama_sampler * grmr;
     struct llama_sampler * chain;
@@ -125,7 +125,7 @@ struct common_sampler {
     }
 };
 
-std::string common_sampler_params::print() const {
+std::string common_params_sampling::print() const {
     char result[1024];
 
     snprintf(result, sizeof(result),
@@ -141,7 +141,7 @@ std::string common_sampler_params::print() const {
     return std::string(result);
 }
 
-struct common_sampler * common_sampler_init(const struct llama_model * model, const struct common_sampler_params & params) {
+struct common_sampler * common_sampler_init(const struct llama_model * model, const struct common_params_sampling & params) {
     llama_sampler_chain_params lparams = llama_sampler_chain_default_params();
 
     lparams.no_perf = params.no_perf;
@@ -320,6 +320,45 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co
     return cur_p.data[cur_p.selected].id;
 }
 
+std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft, bool grammar_first) {
+    GGML_ASSERT(idxs.size() == draft.size() + 1 && "idxs.size() must be draft.size() + 1");
+
+    std::vector<llama_token> result;
+    result.reserve(idxs.size());
+
+    size_t i = 0;
+    for (; i < draft.size(); i++) {
+        const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first);
+
+        common_sampler_accept(gsmpl, id, true);
+
+        result.push_back(id);
+
+        if (draft[i] != id) {
+            break;
+        }
+    }
+
+    if (i == draft.size()) {
+        const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first);
+
+        common_sampler_accept(gsmpl, id, true);
+
+        result.push_back(id);
+    }
+
+    return result;
+}
+
+std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first) {
+    std::vector<int> idxs(draft.size() + 1);
+    for (size_t i = 0; i < idxs.size(); ++i) {
+        idxs[i] = i;
+    }
+
+    return common_sampler_sample_and_accept_n(gsmpl, ctx, idxs, draft, grammar_first);
+}
+
 uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl) {
     return llama_sampler_get_seed(gsmpl->chain);
 }
index d37f25ad37c4a699c30ec16a1bfb954b13403e33..348911b18888b263e0dd90a9f26e56a2308e1846 100644 (file)
@@ -36,7 +36,7 @@ struct common_sampler;
 
 // llama_sampler API overloads
 
-struct common_sampler * common_sampler_init(const struct llama_model * model, const struct common_sampler_params & params);
+struct common_sampler * common_sampler_init(const struct llama_model * model, const struct common_params_sampling & params);
 
 void common_sampler_free(struct common_sampler * gsmpl);
 
@@ -60,6 +60,27 @@ void common_perf_print(const struct llama_context * ctx, const struct common_sam
 //
 llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first = false);
 
+// generalized version of common_sampler_sample
+//
+// will cross-reference the sampled tokens with a batch of draft tokens and accept those that match
+// if the sampler disagrees at some point, we stop and return the accepted tokens up to now
+//
+//      common_sampler_sample_n(gsmpl, ctx, { idx }, {});
+//
+// is equivalent to
+//
+//      common_sampler_sample(gsmpl, ctx, idx);
+//      common_sampler_accept(gsmpl, token, true);
+//
+// requires: idxs.size() == draft.size() + 1
+//
+// returns at least 1 token, up to idxs.size()
+//
+std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft, bool grammar_first = false);
+
+// assume idxs == [ 0, 1, 2, ..., draft.size() ]
+std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first = false);
+
 uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl);
 
 // helpers
diff --git a/common/speculative.cpp b/common/speculative.cpp
new file mode 100644 (file)
index 0000000..fe315a2
--- /dev/null
@@ -0,0 +1,269 @@
+#include "speculative.h"
+
+#include "log.h"
+#include "common.h"
+#include "sampling.h"
+
+#include <cstring>
+
+#define SPEC_VOCAB_MAX_SIZE_DIFFERENCE  128
+#define SPEC_VOCAB_CHECK_START_TOKEN_ID 5
+
+struct common_speculative {
+    struct llama_context * ctx;
+    struct common_sampler * smpl;
+
+    llama_batch batch;
+    llama_tokens prompt;
+};
+
+struct common_speculative * common_speculative_init(
+        struct llama_context * ctx_dft) {
+    auto * result = new common_speculative {
+        /* .ctx    = */ ctx_dft,
+        /* .smpl   = */ nullptr,
+        /* .batch  = */ llama_batch_init(llama_n_batch(ctx_dft), 0, 1),
+        /* .prompt = */ {},
+    };
+
+    // TODO: optimize or pass from outside?
+#if 0
+    {
+        common_params_sampling params;
+        params.no_perf = false;
+
+        params.top_k = 40;
+        params.top_p = 0.9;
+
+        params.samplers = {
+            COMMON_SAMPLER_TYPE_TOP_K,
+            COMMON_SAMPLER_TYPE_TOP_P,
+            COMMON_SAMPLER_TYPE_INFILL,
+        };
+
+        result->smpl = common_sampler_init(llama_get_model(ctx_dft), params);
+    }
+#else
+    {
+        common_params_sampling params;
+        params.no_perf = false;
+
+        params.top_k = 10;
+
+        params.samplers = {
+            COMMON_SAMPLER_TYPE_TOP_K,
+        };
+
+        result->smpl = common_sampler_init(llama_get_model(ctx_dft), params);
+    }
+#endif
+
+    return result;
+}
+
+void common_speculative_free(struct common_speculative * spec) {
+    common_sampler_free(spec->smpl);
+
+    llama_batch_free(spec->batch);
+
+    delete spec;
+}
+
+bool common_speculative_are_compatible(
+        const struct llama_context * ctx_tgt,
+        const struct llama_context * ctx_dft) {
+    const struct llama_model * model_tgt = llama_get_model(ctx_tgt);
+    const struct llama_model * model_dft = llama_get_model(ctx_dft);
+
+    const bool vocab_type_tgt = llama_vocab_type(model_tgt);
+    LOG_DBG("%s: vocab_type tgt: %d\n", __func__, vocab_type_tgt);
+
+    const bool vocab_type_dft = llama_vocab_type(model_dft);
+    LOG_DBG("%s: vocab_type dft: %d\n", __func__, vocab_type_dft);
+
+    if (vocab_type_tgt != vocab_type_dft) {
+        LOG_ERR("%s: draft model vocab type must match target model to use speculation but "
+                     "vocab_type_dft = %d while vocab_type_tgt = %d\n", __func__, vocab_type_dft, vocab_type_tgt);
+        return false;
+    }
+
+    if (llama_add_bos_token(model_tgt) != llama_add_bos_token(model_dft) ||
+        llama_add_eos_token(model_tgt) != llama_add_eos_token(model_dft) ||
+        llama_token_bos(model_tgt) != llama_token_bos(model_dft) ||
+        llama_token_eos(model_tgt) != llama_token_eos(model_dft)
+    ) {
+        LOG_ERR("%s: draft model special tokens must match target model to use speculation\n", __func__);
+        return false;
+    }
+
+    {
+        const int n_vocab_tgt = llama_n_vocab(model_tgt);
+        const int n_vocab_dft = llama_n_vocab(model_dft);
+
+        const int vocab_diff = std::abs(n_vocab_tgt - n_vocab_dft);
+
+        if (vocab_diff > SPEC_VOCAB_MAX_SIZE_DIFFERENCE) {
+            LOG_ERR("%s: draft model vocab must closely match target model to use speculation but "
+                         "target vocab size %d does not match draft vocab size %d - difference %d, max allowed %d\n",
+                    __func__, n_vocab_tgt, llama_n_vocab(model_dft), vocab_diff, SPEC_VOCAB_MAX_SIZE_DIFFERENCE);
+            return false;
+        }
+
+        for (int i = SPEC_VOCAB_CHECK_START_TOKEN_ID; i < std::min(n_vocab_tgt, n_vocab_dft); ++i) {
+            const char * token_text_tgt = llama_token_get_text(model_tgt, i);
+            const char * token_text_dft = llama_token_get_text(model_dft, i);
+            if (std::strcmp(token_text_tgt, token_text_dft) != 0) {
+                LOG_ERR("%s: draft model vocab must match target model to use speculation but "
+                             "token %d content differs - target '%s', draft '%s'\n", __func__, i,
+                        common_token_to_piece(ctx_tgt, i).c_str(),
+                        common_token_to_piece(ctx_dft, i).c_str());
+                return false;
+            }
+        }
+    }
+
+    return true;
+}
+
+llama_tokens common_speculative_gen_draft(
+        struct common_speculative * spec,
+        struct common_speculative_params params,
+        const llama_tokens & prompt_tgt,
+        llama_token id_last) {
+    auto & batch  = spec->batch;
+    auto & ctx    = spec->ctx;
+    auto & smpl   = spec->smpl;
+    auto & prompt = spec->prompt;
+
+    int reuse_i = 0;
+    int reuse_n = 0;
+
+    const int n_ctx = llama_n_ctx(ctx) - params.n_draft;
+
+    const int i_start = std::max<int>(0, (int) prompt_tgt.size() - n_ctx);
+
+    // reuse as much as possible from the old draft context
+    // ideally, the draft context should be as big as the target context and we will always reuse the entire prompt
+    for (int i = 0; i < (int) prompt.size(); ++i) {
+        int cur = 0;
+        while (i_start + cur < (int) prompt_tgt.size() &&
+               i       + cur < (int) prompt.size() &&
+               prompt_tgt[i_start + cur] == prompt[i + cur]) {
+            cur++;
+        }
+
+        if ((cur >= params.n_reuse || n_ctx >= (int) prompt_tgt.size()) && cur > reuse_n) {
+            reuse_i = i;
+            reuse_n = cur;
+        }
+    }
+
+    LOG_DBG("%s: reuse_i = %d, reuse_n = %d, prompt = %d\n", __func__, reuse_i, reuse_n, (int) prompt.size());
+
+    llama_tokens result;
+    result.reserve(params.n_draft);
+
+    if (reuse_n == 0) {
+        llama_kv_cache_clear(ctx);
+
+        prompt.clear();
+    } else {
+        // this happens when a previous draft has been discarded (for example, due to being too small), but the
+        // target model agreed with it. in this case, we simply pass back the previous results to save compute
+        if (reuse_i + reuse_n < (int) prompt.size() && prompt[reuse_i + reuse_n] == id_last) {
+            for (int i = reuse_i + reuse_n + 1; i < (int) prompt.size(); ++i) {
+                result.push_back(prompt[i]);
+
+                if (params.n_draft <= (int) result.size()) {
+                    break;
+                }
+            }
+
+            return result;
+        }
+
+        if (reuse_i > 0) {
+            llama_kv_cache_seq_rm (ctx, 0, 0, reuse_i);
+            llama_kv_cache_seq_add(ctx, 0, reuse_i, -1, -reuse_i);
+
+            prompt.erase(prompt.begin(), prompt.begin() + reuse_i);
+        }
+
+        if (reuse_n < (int) prompt.size()) {
+            llama_kv_cache_seq_rm (ctx, 0, reuse_n, -1);
+
+            prompt.erase(prompt.begin() + reuse_n, prompt.end());
+        }
+    }
+
+    // prepare a batch to evaluate any new tokens in the prompt
+    common_batch_clear(batch);
+
+    for (size_t i = i_start + reuse_n; i < prompt_tgt.size(); ++i) {
+        //LOG_DBG("i = %d, i_start = %d, reuse_n = %d, i - i_start = %d, id = %6d\n", i, i_start, reuse_n, i - i_start, prompt_tgt[i]);
+        common_batch_add(batch, prompt_tgt[i], i - i_start, { 0 }, false);
+
+        prompt.push_back(prompt_tgt[i]);
+    }
+
+    // we should rarely end-up here during normal decoding
+    if (batch.n_tokens > 0) {
+        //LOG_DBG("%s: draft prompt batch: %s\n", __func__, string_from(ctx, batch).c_str());
+
+        llama_decode(ctx, batch);
+    }
+
+    const llama_pos n_past = prompt.size();
+
+    LOG_DBG("%s: n_past = %d\n", __func__, n_past);
+
+    common_batch_clear(batch);
+    common_batch_add  (batch, id_last, n_past, { 0 }, true);
+
+    prompt.push_back(id_last);
+
+    //LOG_DBG("%s: draft prompt: %s\n", __func__, string_from(ctx, prompt).c_str());
+
+    llama_decode(ctx, batch);
+
+    common_sampler_reset(smpl);
+
+    // sample n_draft tokens from the draft model
+    for (int i = 0; i < params.n_draft; ++i) {
+        common_batch_clear(batch);
+
+        common_sampler_sample(smpl, ctx, 0, true);
+
+        const auto * cur_p = common_sampler_get_candidates(smpl);
+
+        for (int k = 0; k < std::min(3, (int) cur_p->size); ++k) {
+            LOG_DBG(" - draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n",
+                    k, i, cur_p->data[k].id, cur_p->data[k].p, common_token_to_piece(ctx, cur_p->data[k].id).c_str());
+        }
+
+        // add drafted token for each sequence
+        const llama_token id = cur_p->data[0].id;
+
+        // only collect very high-confidence draft tokens
+        if (cur_p->data[0].p < params.p_min) {
+            break;
+        }
+
+        common_sampler_accept(smpl, id, true);
+
+        result.push_back(id);
+
+        if (params.n_draft <= (int) result.size()) {
+            break;
+        }
+
+        common_batch_add(batch, id, n_past + i + 1, { 0 }, true);
+
+        // evaluate the drafted tokens on the draft model
+        llama_decode(ctx, batch);
+
+        prompt.push_back(id);
+    }
+
+    return result;
+}
diff --git a/common/speculative.h b/common/speculative.h
new file mode 100644 (file)
index 0000000..50ec034
--- /dev/null
@@ -0,0 +1,28 @@
+#pragma once
+
+#include "llama.h"
+#include "common.h"
+
+struct common_speculative;
+
+struct common_speculative_params {
+    int n_draft = 16;  // max drafted tokens
+    int n_reuse = 256;
+
+    float p_min = 0.9f; // min probabiliy required to accept a token in the draft
+};
+
+struct common_speculative * common_speculative_init(struct llama_context * ctx_dft);
+
+void common_speculative_free(struct common_speculative * spec);
+
+bool common_speculative_are_compatible(
+        const struct llama_context * ctx_tgt,
+        const struct llama_context * ctx_dft);
+
+// sample up to n_draft tokens and add them to the batch using the draft model
+llama_tokens common_speculative_gen_draft(
+               struct common_speculative * spec,
+        struct common_speculative_params   params,
+                      const llama_tokens & prompt,
+                             llama_token   id_last);
index d63a96c1c25475dc44f978fe231bdb088a871f85..9bd099d4ef8a5653fc2758a1158c1f504b83a0fc 100644 (file)
@@ -50,5 +50,6 @@ else()
     add_subdirectory(simple)
     add_subdirectory(simple-chat)
     add_subdirectory(speculative)
+    add_subdirectory(speculative-simple)
     add_subdirectory(tokenize)
 endif()
index 3b554033e7ee4102e641f43c19e97244457d4f34..ba219cd4b32ae88112824cebb16a6fbf708807d7 100644 (file)
@@ -68,10 +68,10 @@ int main(int argc, char ** argv) {
 
     llama_sampler * smpl = llama_sampler_chain_init(sparams);
 
-    llama_sampler_chain_add(smpl, llama_sampler_init_top_k(params.sparams.top_k));
-    llama_sampler_chain_add(smpl, llama_sampler_init_top_p(params.sparams.top_p, params.sparams.min_keep));
-    llama_sampler_chain_add(smpl, llama_sampler_init_temp (params.sparams.temp));
-    llama_sampler_chain_add(smpl, llama_sampler_init_dist (params.sparams.seed));
+    llama_sampler_chain_add(smpl, llama_sampler_init_top_k(params.sampling.top_k));
+    llama_sampler_chain_add(smpl, llama_sampler_init_top_p(params.sampling.top_p, params.sampling.min_keep));
+    llama_sampler_chain_add(smpl, llama_sampler_init_temp (params.sampling.temp));
+    llama_sampler_chain_add(smpl, llama_sampler_init_dist (params.sampling.seed));
 
     if (ctx == NULL) {
         LOG_ERR("%s: error: failed to create the llama_context\n" , __func__);
index 15b358dc4e854d5101b4c9d4b272f3c8b2ba8f61..ef700895720fffa811304db695b3b55021040ca9 100644 (file)
@@ -73,7 +73,7 @@ int main(int argc, char ** argv) {
 
     common_init();
 
-    auto & sparams = params.sparams;
+    auto & sparams = params.sampling;
 
     console::init(params.simple_io, params.use_color);
     atexit([]() { console::cleanup(); });
index 1610985858fc9b196a1f926130cb120df2bb4bd2..2691c6e6b2dd2c23a82d148ac87b90b5a2b268a1 100644 (file)
@@ -191,7 +191,7 @@ static void process_prompt(struct llava_context * ctx_llava, struct llava_image_
 
     LOG("\n");
 
-    struct common_sampler * smpl = common_sampler_init(ctx_llava->model, params->sparams);
+    struct common_sampler * smpl = common_sampler_init(ctx_llava->model, params->sampling);
     if (!smpl) {
         LOG_ERR("%s: failed to initialize sampling subsystem\n", __func__);
         exit(1);
index cbecec343c640f2e40690b494f50bce335e693a9..e9cbb51ed90ab8e743c7dbf00413800045b94066 100644 (file)
@@ -237,7 +237,7 @@ static struct common_sampler * llama_init(struct llava_context * ctx_llava, comm
 
     LOG_INF("\n");
 
-    struct common_sampler * smpl = common_sampler_init(ctx_llava->model, params->sparams);
+    struct common_sampler * smpl = common_sampler_init(ctx_llava->model, params->sampling);
     return smpl;
 }
 
index 3c0ccfea2ccd7d1378686a20bb058fdf16066042..8d0ef8b3d75e6909753f37746c17b67e0dd44fa4 100644 (file)
@@ -115,7 +115,7 @@ int main(int argc, char ** argv) {
     llama_batch batch = llama_batch_init(params.n_ctx, 0, W + G + 1);
 
     // target model sampling context
-    struct common_sampler * smpl = common_sampler_init(model, params.sparams);
+    struct common_sampler * smpl = common_sampler_init(model, params.sampling);
 
     // verification n-grams
     std::vector<ngram_data> ngrams_cur(G);
index 7faebe7ba11fcf21248c99d3e663745fc71d6868..dff07c075c47fb6e9b61f9d80955be49f2350d1b 100644 (file)
@@ -21,7 +21,7 @@ int main(int argc, char ** argv){
 
     common_init();
 
-    const int n_draft = params.n_draft;
+    const int n_draft = params.speculative.n_max;
 
     // init llama.cpp
     llama_backend_init();
@@ -40,6 +40,7 @@ int main(int argc, char ** argv){
     common_ngram_cache ngram_cache_context;
     common_ngram_cache ngram_cache_dynamic;
     common_ngram_cache ngram_cache_static;
+
     int64_t t_draft_flat_us = 0;
     int64_t t_draft_us = 0;
 
index a04728b1834cc82f576984f70a63e28607105bca..4d92bb23853585d268111fb3f0567f30605a4215 100644 (file)
@@ -22,7 +22,7 @@ int main(int argc, char ** argv){
     common_init();
 
     // max. number of additional tokens to draft if match is found
-    const int n_draft = params.n_draft;
+    const int n_draft = params.speculative.n_max;
 
     const bool dump_kv_cache = params.dump_kv_cache;
 
@@ -102,7 +102,7 @@ int main(int argc, char ** argv){
 
     bool has_eos = false;
 
-    struct common_sampler * smpl = common_sampler_init(model, params.sparams);
+    struct common_sampler * smpl = common_sampler_init(model, params.sampling);
 
     std::vector<llama_token> draft;
 
index 7c4ce4be2abaee16e07622984bf69c9b3538ee28..957451af7ce0ac58a48ca82fc7724d484229afaf 100644 (file)
@@ -100,7 +100,7 @@ int main(int argc, char ** argv) {
 
     common_init();
 
-    auto & sparams = params.sparams;
+    auto & sparams = params.sampling;
 
     // save choice to use color for later
     // (note for later: this is a slightly awkward choice)
index 43c8f3ed56ba9c6d37193efcb05592f5e5f031da..fd2b1c011283822fd6e8d766b5c8cea0326b035c 100644 (file)
@@ -160,7 +160,7 @@ int main(int argc, char ** argv) {
     for (size_t i = 0; i < clients.size(); ++i) {
         auto & client = clients[i];
         client.id = i;
-        client.smpl = common_sampler_init(model, params.sparams);
+        client.smpl = common_sampler_init(model, params.sampling);
     }
 
     std::vector<llama_token> tokens_system;
index 1768aae5100679e7d5c9ce507b696a7b0fe10faa..e78a8596d8cfe41b66a13570c8e0b6164449df46 100644 (file)
@@ -282,8 +282,8 @@ int main(int argc, char ** argv) {
                 return a.second > b.second;
             });
 
-            LOG("Top %d similar chunks:\n", params.sparams.top_k);
-            for (int i = 0; i < std::min(params.sparams.top_k, (int) chunks.size()); i++) {
+            LOG("Top %d similar chunks:\n", params.sampling.top_k);
+            for (int i = 0; i < std::min(params.sampling.top_k, (int) chunks.size()); i++) {
                 LOG("filename: %s\n", chunks[similarities[i].first].filename.c_str());
                 LOG("filepos: %lld\n", (long long int) chunks[similarities[i].first].filepos);
                 LOG("similarity: %f\n", similarities[i].second);
index 8c49a52a661249ef64146728d7673c27d096158a..2f0cf9baa32b71fdc2edfadb6d7b49f681fb4796 100644 (file)
@@ -9,7 +9,7 @@ int main(int argc, char ** argv) {
     common_params params;
 
     params.prompt = "The quick brown fox";
-    params.sparams.seed = 1234;
+    params.sampling.seed = 1234;
 
     if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_COMMON)) {
         return 1;
@@ -42,7 +42,7 @@ int main(int argc, char ** argv) {
 
     llama_sampler * smpl = llama_sampler_chain_init(sparams);
 
-    llama_sampler_chain_add(smpl, llama_sampler_init_dist(params.sparams.seed));
+    llama_sampler_chain_add(smpl, llama_sampler_init_dist(params.sampling.seed));
 
     // tokenize prompt
     auto tokens = common_tokenize(ctx, params.prompt, true);
@@ -106,7 +106,7 @@ int main(int argc, char ** argv) {
 
     llama_sampler * smpl2 = llama_sampler_chain_init(sparams);
 
-    llama_sampler_chain_add(smpl2, llama_sampler_init_dist(params.sparams.seed));
+    llama_sampler_chain_add(smpl2, llama_sampler_init_dist(params.sampling.seed));
 
     printf("\nsecond run: %s", params.prompt.c_str());
 
@@ -169,7 +169,7 @@ int main(int argc, char ** argv) {
 
     llama_sampler * smpl3 = llama_sampler_chain_init(sparams);
 
-    llama_sampler_chain_add(smpl3, llama_sampler_init_dist(params.sparams.seed));
+    llama_sampler_chain_add(smpl3, llama_sampler_init_dist(params.sampling.seed));
 
     printf("\nsingle seq run: %s", params.prompt.c_str());
 
index b8e003be9730e26e3f80304efbb72c4593b05b06..6c55d65c013305eecddbaad34c5fafd6ba20208d 100644 (file)
@@ -175,7 +175,7 @@ struct server_slot {
     // sampling
     json json_schema;
 
-    struct common_sampler_params sparams;
+    struct common_params_sampling sparams;
     struct common_sampler * smpl = nullptr;
 
     llama_token sampled;
@@ -687,7 +687,7 @@ struct server_context {
 
             SLT_INF(slot, "new slot n_ctx_slot = %d\n", slot.n_ctx);
 
-            slot.sparams = params.sparams;
+            slot.sparams = params.sampling;
 
             slot.callback_on_release = [this](int) {
                 queue_tasks.pop_deferred_task();
@@ -743,7 +743,7 @@ struct server_context {
                 }
 
                 // length of the Longest Common Subsequence between the current slot's prompt and the input prompt
-                int cur_lcs_len = longest_common_subsequence(slot.cache_tokens, task.prompt_tokens);
+                int cur_lcs_len = common_lcs(slot.cache_tokens, task.prompt_tokens);
 
                 // fraction of the common subsequence length compared to the current slot's prompt length
                 float cur_similarity = static_cast<float>(cur_lcs_len) / static_cast<int>(slot.cache_tokens.size());
@@ -788,7 +788,7 @@ struct server_context {
     bool launch_slot_with_task(server_slot & slot, const server_task & task) {
         slot_params default_params;
         // Sampling parameter defaults are loaded from the global server context (but individual requests can still override them)
-        auto default_sparams = params.sparams;
+        auto default_sparams = params.sampling;
         const auto & data = task.data;
 
         if (data.count("__oaicompat") != 0) {
@@ -1960,7 +1960,7 @@ struct server_context {
 
                             if (slot.params.cache_prompt) {
                                 // reuse any previously computed tokens that are common with the new prompt
-                                slot.n_past = longest_common_prefix(slot.cache_tokens, prompt_tokens);
+                                slot.n_past = common_lcp(slot.cache_tokens, prompt_tokens);
 
                                 // reuse chunks from the cached prompt by shifting their KV cache in the new position
                                 if (params.n_cache_reuse > 0) {
index c47ed3e47a76dc35aacc7656367e56505769186d..1665e9dc37db6bec86639cdd6d4a6224af5ab17a 100644 (file)
@@ -24,7 +24,6 @@
 #define DEFAULT_OAICOMPAT_MODEL "gpt-3.5-turbo-0613"
 
 using json = nlohmann::ordered_json;
-using llama_tokens = std::vector<llama_token>;
 
 #define SLT_INF(slot, fmt, ...) LOG_INF("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__)
 #define SLT_WRN(slot, fmt, ...) LOG_WRN("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__)
@@ -439,62 +438,6 @@ static std::string gen_chatcmplid() {
 // other common utils
 //
 
-static size_t longest_common_prefix(const llama_tokens & a, const llama_tokens & b) {
-    size_t i;
-    for (i = 0; i < a.size() && i < b.size() && a[i] == b[i]; i++) {}
-
-    return i;
-}
-
-static size_t longest_common_subsequence(const llama_tokens & a, const llama_tokens & b) {
-    // check for empty sequences
-    if (a.empty() || b.empty()) {
-        return 0;
-    }
-
-    // get the lengths of the input sequences
-    size_t a_len = a.size();
-    size_t b_len = b.size();
-
-    // initialize the maximum length of the longest common subsequence (LCS)
-    size_t max_length = 0;
-
-    // use two rows instead of a 2D matrix to optimize space
-    std::vector<size_t> prev_row(b_len + 1, 0);
-    std::vector<size_t> curr_row(b_len + 1, 0);
-
-    // iterate through the elements of a
-    for (size_t i = 1; i <= a_len; i++) {
-        // iterate through the elements of b
-        for (size_t j = 1; j <= b_len; j++) {
-            // if elements at the current positions match
-            if (a[i - 1] == b[j - 1]) {
-                // if it's the first element of either sequences, set LCS length to 1
-                if (i == 1 || j == 1) {
-                    curr_row[j] = 1;
-                } else {
-                    // increment LCS length by 1 compared to the previous element
-                    curr_row[j] = prev_row[j - 1] + 1;
-                }
-
-                // update max_length if necessary
-                if (curr_row[j] > max_length) {
-                    max_length = curr_row[j];
-                }
-            } else {
-                // reset LCS length if elements don't match
-                curr_row[j] = 0;
-            }
-        }
-
-        // update the previous row for the next iteration
-        prev_row = curr_row;
-    }
-
-    // return the maximum length of the LCS
-    return max_length;
-}
-
 static bool ends_with(const std::string & str, const std::string & suffix) {
     return str.size() >= suffix.size() && 0 == str.compare(str.size() - suffix.size(), suffix.size(), suffix);
 }
diff --git a/examples/speculative-simple/CMakeLists.txt b/examples/speculative-simple/CMakeLists.txt
new file mode 100644 (file)
index 0000000..7a3a141
--- /dev/null
@@ -0,0 +1,5 @@
+set(TARGET llama-speculative-simple)
+add_executable(${TARGET} speculative-simple.cpp)
+install(TARGETS ${TARGET} RUNTIME)
+target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
+target_compile_features(${TARGET} PRIVATE cxx_std_11)
diff --git a/examples/speculative-simple/README.md b/examples/speculative-simple/README.md
new file mode 100644 (file)
index 0000000..e3a6c6b
--- /dev/null
@@ -0,0 +1,12 @@
+# llama.cpp/examples/speculative-simple
+
+Demonstration of basic greedy speculative decoding
+
+```bash
+./bin/llama-speculative-simple \
+    -m  ../models/qwen2.5-32b-coder-instruct/ggml-model-q8_0.gguf \
+    -md ../models/qwen2.5-1.5b-coder-instruct/ggml-model-q4_0.gguf \
+    -f test.txt -c 0 -ngl 99 --color \
+    --sampling-seq k --top-k 1 -fa --temp 0.0 \
+    -ngld 99 --draft-max 16 --draft-min 5 --draft-p-min 0.9
+```
diff --git a/examples/speculative-simple/speculative-simple.cpp b/examples/speculative-simple/speculative-simple.cpp
new file mode 100644 (file)
index 0000000..1bc7f42
--- /dev/null
@@ -0,0 +1,273 @@
+#include "arg.h"
+#include "common.h"
+#include "sampling.h"
+#include "speculative.h"
+#include "log.h"
+#include "llama.h"
+
+#include <cstdio>
+#include <cstring>
+#include <string>
+#include <vector>
+
+int main(int argc, char ** argv) {
+    common_params params;
+
+    if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_SPECULATIVE)) {
+        return 1;
+    }
+
+    if (params.n_predict < -1) {
+        LOG_ERR("%s: --n-predict must be >= -1\n", __func__);
+        return 1;
+    }
+
+    common_init();
+
+    if (params.speculative.model.empty()) {
+        LOG_ERR("%s: --model-draft is required\n", __func__);
+        return 1;
+    }
+
+    // init llama.cpp
+    llama_backend_init();
+    llama_numa_init(params.numa);
+
+    llama_model * model_tgt = NULL;
+    llama_model * model_dft = NULL;
+
+    llama_context * ctx_tgt = NULL;
+    llama_context * ctx_dft = NULL;
+
+    // load the target model
+    common_init_result llama_init_tgt = common_init_from_params(params);
+
+    model_tgt = llama_init_tgt.model;
+    ctx_tgt   = llama_init_tgt.context;
+
+    // load the draft model
+    params.model        = params.speculative.model;
+    params.n_ctx        = params.speculative.n_ctx;
+    params.n_batch      = params.speculative.n_ctx > 0 ? params.speculative.n_ctx : params.n_batch;
+    params.n_gpu_layers = params.speculative.n_gpu_layers;
+
+    if (params.speculative.cpuparams.n_threads > 0) {
+        params.cpuparams.n_threads = params.speculative.cpuparams.n_threads;
+    }
+
+    params.cpuparams_batch.n_threads = params.speculative.cpuparams_batch.n_threads;
+    common_init_result llama_init_dft = common_init_from_params(params);
+
+    model_dft = llama_init_dft.model;
+    ctx_dft   = llama_init_dft.context;
+
+    if (!common_speculative_are_compatible(ctx_tgt, ctx_dft)) {
+        return 1;
+    }
+
+    // Tokenize the prompt
+    std::vector<llama_token> inp;
+    inp = common_tokenize(ctx_tgt, params.prompt, true, true);
+
+    if (llama_n_ctx(ctx_tgt) < (int) inp.size()) {
+        LOG_ERR("%s: the prompt exceeds the context size (%d tokens, ctx %d)\n", __func__, (int) inp.size(), llama_n_ctx(ctx_tgt));
+
+        return 1;
+    }
+
+    if (llama_n_batch(ctx_tgt) < (int) inp.size()) {
+        LOG_ERR("%s: the prompt exceeds the batch size (%d tokens, batch %d)\n", __func__, (int) inp.size(), llama_n_batch(ctx_tgt));
+
+        return 1;
+    }
+
+    LOG("\n\n");
+
+    for (auto id : inp) {
+        LOG("%s", common_token_to_piece(ctx_tgt, id).c_str());
+    }
+
+    // how many tokens to draft each time
+    int n_draft     = params.speculative.n_max;
+    int n_draft_min = params.speculative.n_min;
+
+    float p_min = params.speculative.p_min;
+
+    int n_predict = 0;
+    int n_drafted = 0;
+    int n_accept  = 0;
+
+    // used to determine end of generation
+    bool has_eos = false;
+
+    // ================================================
+    // everything until here is standard initialization
+    // the relevant stuff for speculative decoding starts here
+
+    const auto t_enc_start = ggml_time_us();
+
+    // target model sampling context
+    struct common_sampler * smpl = common_sampler_init(model_tgt, params.sampling);
+
+    // eval the prompt
+    llama_decode(ctx_tgt, llama_batch_get_one(inp.data(), inp.size() - 1));
+
+    // note: keep the last token separate!
+    llama_token id_last = inp.back();
+
+    // all tokens currently in the target context
+    auto prompt_tgt = std::vector<llama_token>(inp.begin(), inp.end() - 1);
+
+    int n_past = inp.size() - 1;
+
+    // init the speculator
+    struct common_speculative_params params_spec;
+    params_spec.n_draft = n_draft;
+    params_spec.n_reuse = llama_n_ctx(ctx_dft) - n_draft;
+    params_spec.p_min   = p_min;
+
+    struct common_speculative * spec = common_speculative_init(ctx_dft);
+
+    llama_batch batch_tgt = llama_batch_init(llama_n_batch(ctx_tgt), 0, 1);
+
+    const auto t_enc_end = ggml_time_us();
+
+    const auto t_dec_start = ggml_time_us();
+
+    while (true) {
+        // optionally, generate draft tokens that can be appended to the target batch
+        //
+        // this is the most important part of the speculation. the more probable tokens that are provided here
+        // the better the performance will be. in theory, this computation can be performed asynchronously and even
+        // offloaded to a remote device. it doesn't even have to be based on an LLM. instead, it can provide tokens
+        // from a cache or lookup tables.
+        //
+        llama_tokens draft = common_speculative_gen_draft(spec, params_spec, prompt_tgt, id_last);
+
+        //LOG_DBG("draft: %s\n", string_from(ctx_dft, draft).c_str());
+
+        // always have a token to evaluate from before - id_last
+        common_batch_clear(batch_tgt);
+        common_batch_add  (batch_tgt, id_last, n_past++, { 0 }, true);
+
+        // evaluate the target model on [id_last, draft0, draft1, ..., draftN-1]
+        {
+            // do not waste time on small drafts
+            if (draft.size() < n_draft_min) {
+                draft.clear();
+            }
+
+            for (size_t i = 0; i < draft.size(); ++i) {
+                common_batch_add(batch_tgt, draft[i], n_past + i, { 0 }, true);
+            }
+
+            //LOG_DBG("target batch: %s\n", string_from(ctx_tgt, batch_tgt).c_str());
+
+            llama_decode(ctx_tgt, batch_tgt);
+        }
+
+        // sample from the full target batch and return the accepted tokens based on the target sampler
+        //
+        // for each token to be accepted, the sampler would have to sample that same token
+        // in such cases, instead of decoding the sampled token as we normally do, we simply continue with the
+        // available logits from the batch and sample the next token until we run out of logits or the sampler
+        // disagrees with the draft
+        //
+        const auto ids = common_sampler_sample_and_accept_n(smpl, ctx_tgt, draft);
+
+        //LOG_DBG("ids: %s\n", string_from(ctx_tgt, ids).c_str());
+
+        GGML_ASSERT(ids.size() > 0); // there will always be at least one accepted token
+
+        n_past    += ids.size() - 1;
+        n_drafted += batch_tgt.n_tokens - 1;
+        n_accept  += ids.size() - 1;
+
+        // process the accepted tokens and update contexts
+        //
+        // this is the standard token post-processing that we normally do
+        // in this case, we do it for a group of accepted tokens at once
+        //
+        {
+            llama_token id;
+            std::string token_str;
+
+            for (size_t i = 0; i < ids.size(); ++i) {
+                id = ids[i];
+
+                ++n_predict;
+
+                if (llama_token_is_eog(model_tgt, id)) {
+                    has_eos = true;
+                    break;
+                }
+
+                token_str = common_token_to_piece(ctx_tgt, id);
+
+                if (params.use_color && i + 1 < ids.size()) {
+                    LOG("\u001b[%dm%s\u001b[37m", (36 - 0 % 6), token_str.c_str());
+                } else {
+                    LOG("%s", token_str.c_str());
+                }
+            }
+
+            if ((params.n_predict >= 0 && n_predict > params.n_predict) || has_eos) {
+                break;
+            }
+
+            LOG_DBG("accepted %d/%d draft tokens, the last target token is: (%d, '%s')\n", (int) ids.size() - 1, (int) draft.size(), id, token_str.c_str());
+
+            {
+                LOG_DBG("clear kv cache from any extra tokens, n_past = %d\n", n_past);
+
+                llama_kv_cache_seq_rm(ctx_tgt, 0, n_past, -1);
+            }
+
+            prompt_tgt.push_back(id_last);
+            prompt_tgt.insert(prompt_tgt.end(), ids.begin(), ids.end() - 1);
+
+            // remember the last accepted token for the next iteration
+            id_last = id;
+        }
+    }
+
+    auto t_dec_end = ggml_time_us();
+
+    const int n_input = inp.size();
+
+    LOG("\n\n");
+
+    LOG_INF("encoded %4d tokens in %8.3f seconds, speed: %8.3f t/s\n", n_input,   (t_enc_end - t_enc_start) / 1e6f, inp.size() / ((t_enc_end - t_enc_start) / 1e6f));
+    LOG_INF("decoded %4d tokens in %8.3f seconds, speed: %8.3f t/s\n", n_predict, (t_dec_end - t_dec_start) / 1e6f, n_predict  / ((t_dec_end - t_dec_start) / 1e6f));
+
+    LOG_INF("\n");
+    LOG_INF("n_draft   = %d\n", n_draft);
+    LOG_INF("n_predict = %d\n", n_predict);
+    LOG_INF("n_drafted = %d\n", n_drafted);
+    LOG_INF("n_accept  = %d\n", n_accept);
+    LOG_INF("accept    = %.3f%%\n", 100.0f * n_accept / n_drafted);
+
+    LOG_INF("\n");
+    LOG_INF("draft:\n\n");
+
+    llama_perf_context_print(ctx_dft);
+
+    LOG_INF("\n");
+    LOG_INF("target:\n\n");
+    common_perf_print(ctx_tgt, smpl);
+
+    common_sampler_free(smpl);
+    common_speculative_free(spec);
+
+    llama_free(ctx_tgt);
+    llama_free_model(model_tgt);
+
+    llama_free(ctx_dft);
+    llama_free_model(model_dft);
+
+    llama_backend_free();
+
+    LOG("\n\n");
+
+    return 0;
+}
index 6cafd8a8379922fda2e98aacaebdccc8f8bf2012..eb8bb2de54a268e1465814e720ae98edb1bde1dc 100644 (file)
@@ -12,7 +12,7 @@
 #include <string>
 #include <vector>
 
-#define SPEC_VOCAB_MAX_SIZE_DIFFERENCE  100
+#define SPEC_VOCAB_MAX_SIZE_DIFFERENCE  128
 #define SPEC_VOCAB_CHECK_START_TOKEN_ID 5
 
 struct seq_draft {
@@ -33,7 +33,7 @@ int main(int argc, char ** argv) {
     common_params params;
 
     // needed to get candidate probs even for temp <= 0.0
-    params.sparams.n_probs = 128;
+    params.sampling.n_probs = 128;
 
     if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_SPECULATIVE)) {
         return 1;
@@ -46,7 +46,7 @@ int main(int argc, char ** argv) {
 
     common_init();
 
-    if (params.model_draft.empty()) {
+    if (params.speculative.model.empty()) {
         LOG_ERR("%s: --model-draft is required\n", __func__);
         return 1;
     }
@@ -55,9 +55,9 @@ int main(int argc, char ** argv) {
     const int n_seq_dft = params.n_parallel;
 
     // probability threshold for splitting a draft branch (only for n_seq_dft > 1)
-    const float p_split  = params.p_split;
+    const float p_draft_split = params.speculative.p_split;
 
-    std::default_random_engine rng(params.sparams.seed == LLAMA_DEFAULT_SEED ? std::random_device()() : params.sparams.seed);
+    std::default_random_engine rng(params.sampling.seed == LLAMA_DEFAULT_SEED ? std::random_device()() : params.sampling.seed);
     std::uniform_real_distribution<> u_dist;
 
     // init llama.cpp
@@ -76,13 +76,13 @@ int main(int argc, char ** argv) {
     ctx_tgt = llama_init_tgt.context;
 
     // load the draft model
-    params.model = params.model_draft;
-    params.n_gpu_layers = params.n_gpu_layers_draft;
-    if (params.draft_cpuparams.n_threads > 0) {
-        params.cpuparams.n_threads = params.draft_cpuparams.n_threads;
+    params.model = params.speculative.model;
+    params.n_gpu_layers = params.speculative.n_gpu_layers;
+    if (params.speculative.cpuparams.n_threads > 0) {
+        params.cpuparams.n_threads = params.speculative.cpuparams.n_threads;
     }
 
-    params.cpuparams_batch.n_threads = params.draft_cpuparams_batch.n_threads;
+    params.cpuparams_batch.n_threads = params.speculative.cpuparams_batch.n_threads;
     common_init_result llama_init_dft = common_init_from_params(params);
     model_dft = llama_init_dft.model;
     ctx_dft = llama_init_dft.context;
@@ -170,7 +170,7 @@ int main(int argc, char ** argv) {
     //GGML_ASSERT(n_vocab == llama_n_vocab(model_dft));
 
     // how many tokens to draft each time
-    int n_draft = params.n_draft;
+    int n_draft = params.speculative.n_max;
 
     int n_predict = 0;
     int n_drafted = 0;
@@ -183,14 +183,14 @@ int main(int argc, char ** argv) {
     bool has_eos = false;
 
     // target model sampling context (reuse the llama_context's sampling instance)
-    struct common_sampler * smpl = common_sampler_init(model_tgt, params.sparams);
+    struct common_sampler * smpl = common_sampler_init(model_tgt, params.sampling);
 
     // draft sequence data
     std::vector<seq_draft> drafts(n_seq_dft);
 
     for (int s = 0; s < n_seq_dft; ++s) {
         // allocate llama_sampler for each draft sequence
-        drafts[s].smpl = common_sampler_init(model_dft, params.sparams);
+        drafts[s].smpl = common_sampler_init(model_dft, params.sampling);
     }
 
     llama_batch batch_dft = llama_batch_init(llama_n_batch(ctx_dft), 0, 1);
@@ -230,7 +230,7 @@ int main(int argc, char ** argv) {
             // for stochastic sampling, attempt to match the token with the drafted tokens
             {
                 bool accept = false;
-                if (params.sparams.temp > 0) {
+                if (params.sampling.temp > 0) {
                     // stochastic verification
                     common_sampler_sample(smpl, ctx_tgt, drafts[s_keep].i_batch_tgt[i_dft], true);
 
@@ -494,7 +494,7 @@ int main(int argc, char ** argv) {
 
                 // attempt to split the branch if the probability is high enough
                 for (int f = 1; f < 8; ++f) {
-                    if (n_seq_cur < n_seq_dft && cur_p->data[f].p > p_split) {
+                    if (n_seq_cur < n_seq_dft && cur_p->data[f].p > p_draft_split) {
                         LOG_DBG("splitting seq %3d into %3d\n", s, n_seq_cur);
 
                         llama_kv_cache_seq_rm(ctx_dft,    n_seq_cur, -1, -1);
index 3665238b5a2d89d6d522c2ddb4b051ef78f99117..69604b87ceec4376b5d2a0e17250d50b3e61fa16 100644 (file)
@@ -70,7 +70,7 @@ int main(void) {
 
     // non-existence arg in specific example (--draft cannot be used outside llama-speculative)
     argv = {"binary_name", "--draft", "123"};
-    assert(false == common_params_parse(argv.size(), list_str_to_char(argv).data(), params, LLAMA_EXAMPLE_SERVER));
+    assert(false == common_params_parse(argv.size(), list_str_to_char(argv).data(), params, LLAMA_EXAMPLE_EMBEDDING));
 
 
     printf("test-arg-parser: test valid usage\n\n");
@@ -96,7 +96,7 @@ int main(void) {
     // --draft cannot be used outside llama-speculative
     argv = {"binary_name", "--draft", "123"};
     assert(true == common_params_parse(argv.size(), list_str_to_char(argv).data(), params, LLAMA_EXAMPLE_SPECULATIVE));
-    assert(params.n_draft == 123);
+    assert(params.speculative.n_max == 123);
 
 // skip this part on windows, because setenv is not supported
 #ifdef _WIN32