]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
speculative : simplify the implementation (#10504)
authorGeorgi Gerganov <redacted>
Tue, 26 Nov 2024 10:29:38 +0000 (12:29 +0200)
committerGitHub <redacted>
Tue, 26 Nov 2024 10:29:38 +0000 (12:29 +0200)
ggml-ci

examples/speculative-simple/speculative-simple.cpp

index 7bf9056bf6db10fa9139ec8c4baa12d11081fe5b..2ea49d47c433ef8ed732fb07c9affd04d03daf35 100644 (file)
@@ -117,7 +117,8 @@ int main(int argc, char ** argv) {
     llama_token id_last = inp.back();
 
     // all tokens currently in the target context
-    auto prompt_tgt = std::vector<llama_token>(inp.begin(), inp.end() - 1);
+    llama_tokens prompt_tgt(inp.begin(), inp.end() - 1);
+    prompt_tgt.reserve(llama_n_ctx(ctx_tgt));
 
     int n_past = inp.size() - 1;
 
@@ -181,54 +182,44 @@ int main(int argc, char ** argv) {
         GGML_ASSERT(ids.size() > 0); // there will always be at least one accepted token
 
         n_past    += ids.size() - 1;
-        n_drafted += batch_tgt.n_tokens - 1;
+        n_drafted += draft.size(); // note: we ignore the discarded small drafts
         n_accept  += ids.size() - 1;
+        n_predict += ids.size();
 
         // process the accepted tokens and update contexts
         //
         // this is the standard token post-processing that we normally do
         // in this case, we do it for a group of accepted tokens at once
         //
-        {
-            llama_token id;
-            std::string token_str;
-
-            for (size_t i = 0; i < ids.size(); ++i) {
-                id = ids[i];
-
-                ++n_predict;
-
-                if (llama_token_is_eog(model_tgt, id)) {
-                    has_eos = true;
-                    break;
-                }
-
-                token_str = common_token_to_piece(ctx_tgt, id);
+        for (size_t i = 0; i < ids.size(); ++i) {
+            prompt_tgt.push_back(id_last);
 
-                if (params.use_color && i + 1 < ids.size()) {
-                    LOG("\u001b[%dm%s\u001b[37m", (36 - 0 % 6), token_str.c_str());
-                } else {
-                    LOG("%s", token_str.c_str());
-                }
-            }
+            id_last = ids[i];
 
-            if ((params.n_predict >= 0 && n_predict > params.n_predict) || has_eos) {
+            if (llama_token_is_eog(model_tgt, id_last)) {
+                has_eos = true;
                 break;
             }
 
-            LOG_DBG("accepted %d/%d draft tokens, the last target token is: (%d, '%s')\n", (int) ids.size() - 1, (int) draft.size(), id, token_str.c_str());
+            const std::string token_str = common_token_to_piece(ctx_tgt, id_last);
 
-            {
-                LOG_DBG("clear kv cache from any extra tokens, n_past = %d\n", n_past);
-
-                llama_kv_cache_seq_rm(ctx_tgt, 0, n_past, -1);
+            if (params.use_color && i + 1 < ids.size()) {
+                LOG("\u001b[%dm%s\u001b[37m", (36 - 0 % 6), token_str.c_str());
+            } else {
+                LOG("%s", token_str.c_str());
             }
+        }
 
-            prompt_tgt.push_back(id_last);
-            prompt_tgt.insert(prompt_tgt.end(), ids.begin(), ids.end() - 1);
+        LOG_DBG("accepted %d/%d draft tokens, the last target token is: (%d)\n", (int) ids.size() - 1, (int) draft.size(), id_last);
+
+        {
+            LOG_DBG("clear kv cache from any extra tokens, n_past = %d\n", n_past);
+
+            llama_kv_cache_seq_rm(ctx_tgt, 0, n_past, -1);
+        }
 
-            // remember the last accepted token for the next iteration
-            id_last = id;
+        if ((params.n_predict >= 0 && n_predict > params.n_predict) || has_eos) {
+            break;
         }
     }