]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
speculative : add heuristic algorithm (#3006)
authorLeng Yue <redacted>
Thu, 14 Sep 2023 16:14:44 +0000 (09:14 -0700)
committerGitHub <redacted>
Thu, 14 Sep 2023 16:14:44 +0000 (19:14 +0300)
* Add heuristic algo for speculative

* Constrain minimum n_draft to 2

* speculative : improve heuristic impl

* speculative : be more rewarding upon guessing max drafted tokens

* speculative : fix typos

---------

Co-authored-by: Georgi Gerganov <redacted>
examples/speculative/speculative.cpp

index 2cd153f9a4e171a0e299225526e09511cdb415bf..aa904183fa2d8ebc77cd012972085d405e5ae09f 100644 (file)
@@ -82,7 +82,7 @@ int main(int argc, char ** argv) {
     //GGML_ASSERT(n_vocab == llama_n_vocab(ctx_dft));
 
     // how many tokens to draft each time
-    const int n_draft = params.n_draft;
+    int n_draft = params.n_draft;
 
     int n_predict = 0;
     int n_drafted = 0;
@@ -131,6 +131,7 @@ int main(int argc, char ** argv) {
         LOG("drafted: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_dft, drafted));
 
         int i_dft = 0;
+
         while (true) {
             // sample from the target model
             const llama_token id = llama_sample_token(ctx_tgt, NULL, grammar_tgt, params, last_tokens, candidates, i_dft);
@@ -174,6 +175,27 @@ int main(int argc, char ** argv) {
             llama_eval(ctx_dft, &id, 1, n_past_dft, params.n_threads);
             ++n_past_dft;
 
+            // heuristic for n_draft
+            {
+                const int  n_draft_cur  = (int) drafted.size();
+                const bool all_accepted = i_dft == n_draft_cur;
+
+                LOG("n_draft      = %d\n", n_draft);
+                LOG("n_draft_cur  = %d\n", n_draft_cur);
+                LOG("i_dft        = %d\n", i_dft);
+                LOG("all_accepted = %d\n", all_accepted);
+
+                if (all_accepted && n_draft == n_draft_cur) {
+                    LOG(" - max drafted tokens accepted - n_draft += 8\n");
+                    n_draft = std::min(30, n_draft + 8);
+                } else if (all_accepted) {
+                    LOG(" - partially drafted tokens accepted - no change\n");
+                } else {
+                    LOG(" - drafted token rejected - n_draft -= 1\n");
+                    n_draft = std::max(2, n_draft - 1);
+                }
+            }
+
             drafted.clear();
             drafted.push_back(id);