]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
speculative : bug fixes
authorGeorgi Gerganov <redacted>
Wed, 18 Oct 2023 15:49:40 +0000 (18:49 +0300)
committerGeorgi Gerganov <redacted>
Wed, 18 Oct 2023 15:49:40 +0000 (18:49 +0300)
examples/speculative/speculative.cpp

index 53f42fad8233b63f76627306099db09eab78a09a..24f49012a4baaa5056c86ef0b19efcf913fd7575 100644 (file)
@@ -37,8 +37,8 @@ int main(int argc, char ** argv) {
     const int n_seq_dft = params.n_parallel;
 
     // TODO: make this configurable
-    const float p_accept = 0.4f;
-    const float p_split  = 0.3f;
+    const float p_accept = 0.80f;
+    const float p_split  = 0.10f;
 
 #ifndef LOG_DISABLE_LOGS
     log_set_target(log_filename_generator("speculative", "log"));
@@ -118,7 +118,7 @@ int main(int argc, char ** argv) {
     std::vector<seq_draft> drafts(n_seq_dft);
 
     params.grammar.clear();             // the draft samplers will copy the target sampler's grammar
-    params.sampling_params.temp = 1.0f; // the draft samplers use default temperature
+    params.sampling_params.temp = std::max(0.01f, params.sampling_params.temp);
 
     for (int s = 0; s < n_seq_dft; ++s) {
         drafts[s].ctx_sampling = llama_sampling_init(params);
@@ -156,7 +156,7 @@ int main(int argc, char ** argv) {
 
             llama_sampling_accept(ctx_sampling, ctx_tgt, id);
 
-            //LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_tgt, last_tokens));
+            //LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_tgt, ctx_sampling->prev).c_str());
 
             const std::string token_str = llama_token_to_piece(ctx_tgt, id);
 
@@ -202,7 +202,7 @@ int main(int argc, char ** argv) {
 
             // TODO: simplify
             {
-                LOG("keeping sequence %d\n", s_keep);
+                LOG("keeping sequence %d, n_past_tgt = %d, n_past_dft = %d\n", s_keep, n_past_tgt, n_past_dft);
 
                 llama_kv_cache_seq_keep(ctx_dft, s_keep);
                 llama_kv_cache_seq_cp  (ctx_dft, s_keep, 0, -1, -1);
@@ -277,7 +277,7 @@ int main(int argc, char ** argv) {
                 }
 
                 if (cur_p[0].p < p_accept) {
-                    LOG("stopping drafting for seq %3d, probability too low: %.3f < 2*%.3f\n", s, cur_p[0].p, cur_p[1].p);
+                    LOG("stopping drafting for seq %3d, probability too low: %.3f < %.3f\n", s, cur_p[0].p, p_accept);
                     drafts[s].drafting = false;
                     continue;
                 }
@@ -337,16 +337,14 @@ int main(int argc, char ** argv) {
 
                     llama_batch_add(batch_tgt, id, n_past_tgt + i + 1, { s }, true);
 
-                    // no need to evaluate the last drafted token, since we won't use the result
-                    if (batch_tgt.n_tokens > n_draft) {
-                        drafts[s].drafting = false;
-                        continue;
-                    }
-
                     // add the token to the batch for batched decoding with the draft model
                     drafts[s].i_batch_dft = batch_dft.n_tokens;
 
                     llama_batch_add(batch_dft, id, n_past_cur, { s }, true);
+
+                    if (batch_tgt.n_tokens > n_draft) {
+                        drafts[s].drafting = false;
+                    }
                 }
             }
 
@@ -365,11 +363,6 @@ int main(int argc, char ** argv) {
             }
         }
 
-        // account for the last drafted token that we didn't evaluate
-        if (batch_tgt.n_tokens > n_draft) {
-            ++n_drafted;
-        }
-
         // evaluate the target model on the drafted tokens
         {
             llama_kv_cache_seq_keep(ctx_tgt, 0);