params_dft.n_gpu_layers = params_base.speculative.n_gpu_layers;
params_dft.n_parallel = 1;
+ // force F16 KV cache for the draft model for extra performance
+ params_dft.cache_type_k = GGML_TYPE_F16;
+ params_dft.cache_type_v = GGML_TYPE_F16;
+
llama_init_dft = common_init_from_params(params_dft);
model_dft = llama_init_dft.model.get();
cparams_dft = common_context_params_to_llama(params_dft);
cparams_dft.n_batch = n_ctx_dft;
- // force F16 KV cache for the draft model for extra performance
- cparams_dft.type_k = GGML_TYPE_F16;
- cparams_dft.type_v = GGML_TYPE_F16;
-
// the context is not needed - we will create one for each slot
llama_init_dft.context.reset();
}