]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
speculative : fix handling of some input params (#9963)
authorGeorgi Gerganov <redacted>
Mon, 21 Oct 2024 06:37:12 +0000 (09:37 +0300)
committerGitHub <redacted>
Mon, 21 Oct 2024 06:37:12 +0000 (09:37 +0300)
* speculative : fix batch sizes at initialization

ggml-ci

* speculative : handle params.n_predict == -1

* speculative : limit batch size to llama_n_batch

examples/speculative/speculative.cpp

index b201bd714a447c9e95cd2e7bcafd90d9cd3403ea..8a64754151719673c1b2c1e7cc84012487b991b7 100644 (file)
@@ -39,6 +39,11 @@ int main(int argc, char ** argv) {
         return 1;
     }
 
+    if (params.n_predict < -1) {
+        LOG_ERR("%s: --n-predict must be >= -1\n", __func__);
+        return 1;
+    }
+
     common_init();
 
     if (params.model_draft.empty()) {
@@ -190,8 +195,8 @@ int main(int argc, char ** argv) {
         drafts[s].smpl = common_sampler_init(model_dft, params.sparams);
     }
 
-    llama_batch batch_dft = llama_batch_init(params.n_ctx, 0, 1);
-    llama_batch batch_tgt = llama_batch_init(params.n_ctx, 0, n_seq_dft);
+    llama_batch batch_dft = llama_batch_init(llama_n_batch(ctx_dft), 0, 1);
+    llama_batch batch_tgt = llama_batch_init(llama_n_batch(ctx_tgt), 0, n_seq_dft);
 
     const auto t_dec_start = ggml_time_us();
 
@@ -441,7 +446,7 @@ int main(int argc, char ** argv) {
             ++n_past_dft;
         }
 
-        if (n_predict > params.n_predict || has_eos) {
+        if ((params.n_predict >= 0 && n_predict > params.n_predict) || has_eos) {
             break;
         }