server : fix default draft model parameters (#10586)
author    Georgi Gerganov <redacted>
          Tue, 3 Dec 2024 09:20:00 +0000 (11:20 +0200)
committer GitHub <redacted>
          Tue, 3 Dec 2024 09:20:00 +0000 (11:20 +0200)
* server : force F16 KV cache for the draft model

ggml-ci

* server : fix draft params

ggml-ci

* server : various params fixes

ggml-ci

examples/server/server.cpp

index 8eca14b86d517589b922ab03c295b587054810ef..2ec13d7d2f53613824e1ae0a039a1ce2409f4094 100644
@@ -696,8 +696,9 @@ struct server_context {
 
             params_dft.devices      = params_base.speculative.devices;
             params_dft.model        = params_base.speculative.model;
-            params_dft.n_ctx        = params_base.speculative.n_ctx;
+            params_dft.n_ctx        = params_base.speculative.n_ctx == 0 ? params_base.n_ctx / params_base.n_parallel : params_base.speculative.n_ctx;
             params_dft.n_gpu_layers = params_base.speculative.n_gpu_layers;
+            params_dft.n_parallel   = 1;
 
             common_init_result llama_init_dft = common_init_from_params(params_dft);
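
The new defaults above: when the speculative context size is left at 0 (unset), the draft context falls back to the base context divided by the number of parallel slots, and the draft model itself is initialized with a single parallel sequence. A minimal sketch of the rule, assuming the names from the diff (the helper function itself is hypothetical):

    // Hypothetical helper mirroring the default above: a speculative n_ctx of 0
    // means "unset" and falls back to the per-slot share of the base context.
    static int draft_default_n_ctx(int n_ctx_spec, int n_ctx_base, int n_parallel_base) {
        return n_ctx_spec == 0 ? n_ctx_base / n_parallel_base : n_ctx_spec;
    }
    // e.g. a base n_ctx of 8192 with 4 slots -> each draft context defaults to 2048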
 
@@ -717,8 +718,14 @@ struct server_context {
                 return false;
             }
 
-            cparams_dft = common_context_params_to_llama(params_base);
-            cparams_dft.n_batch = llama_n_ctx(llama_init_dft.context);
+            const int n_ctx_dft = llama_n_ctx(llama_init_dft.context);
+
+            cparams_dft = common_context_params_to_llama(params_dft);
+            cparams_dft.n_batch = n_ctx_dft;
+
+            // force F16 KV cache for the draft model for extra performance
+            cparams_dft.type_k = GGML_TYPE_F16;
+            cparams_dft.type_v = GGML_TYPE_F16;
 
             // the context is not needed - we will create one for each slot
             llama_free(llama_init_dft.context);
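
The first removed line in this hunk was the core bug: cparams_dft was derived from params_base, so base-model settings leaked into the draft model. It is now derived from params_dft, the batch size is pinned to the full draft context, and the draft KV cache is forced to F16 even when the base model runs a quantized cache, trading a little memory for draft-side speed. Per the closing comment, the initial context is freed and a fresh one is created per slot; a minimal sketch of that step, assuming the llama.h API of this era (model_dft is a hypothetical handle to the loaded draft model):

    // Hypothetical per-slot setup implied by the comment above: every slot gets
    // its own draft context, all sharing the F16-cache cparams_dft template.
    llama_context * ctx_dft = llama_new_context_with_model(model_dft, cparams_dft);
    if (ctx_dft == nullptr) {
        // handle failure: this slot cannot run speculative decoding
    }
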
@@ -2322,6 +2329,10 @@ struct server_context {
                     continue;
                 }
 
+                if (slot.state != SLOT_STATE_GENERATING) {
+                    continue;
+                }
+
                 llama_token id = slot.sampled;
 
                 struct common_speculative_params params_spec;
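
The added guard ensures speculative decoding only runs for slots that have already started generating: slot.sampled is meaningful only after the first token has been sampled, so slots still processing their prompt are skipped. A shape-only sketch of the surrounding loop, using the names from the diff (the earlier checks are abridged):

    // Abridged shape of the speculative-decoding loop around the added guard.
    for (server_slot & slot : slots) {
        // ... earlier checks (slot in use, draft model available) ...

        if (slot.state != SLOT_STATE_GENERATING) {
            continue; // prompt still being processed: nothing sampled to draft from
        }

        llama_token id = slot.sampled; // last accepted token seeds the draft
        // ... fill common_speculative_params and run the draft + verification ...
    }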