if (params.n_threads_batch <= 0) {
params.n_threads_batch = std::thread::hardware_concurrency();
}
+ } else if (arg == "-td" || arg == "--threads-draft") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
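+ // mirror -t / -tb: non-positive values fall back to std::thread::hardware_concurrency()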
+ params.n_threads_draft = std::stoi(argv[i]);
+ if (params.n_threads_draft <= 0) {
+ params.n_threads_draft = std::thread::hardware_concurrency();
+ }
+ } else if (arg == "-tbd" || arg == "--threads-batch-draft") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ params.n_threads_batch_draft = std::stoi(argv[i]);
+ if (params.n_threads_batch_draft <= 0) {
+ params.n_threads_batch_draft = std::thread::hardware_concurrency();
+ }
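+ // example invocation (model paths illustrative):
+ //   ./speculative -m target.gguf -md draft.gguf -t 8 -td 4 -tbd 4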
} else if (arg == "-p" || arg == "--prompt") {
if (++i >= argc) {
invalid_param = true;
printf(" -t N, --threads N number of threads to use during generation (default: %d)\n", params.n_threads);
printf(" -tb N, --threads-batch N\n");
printf(" number of threads to use during batch and prompt processing (default: same as --threads)\n");
+ printf(" -td N, --threads-draft N");
+ printf(" number of threads to use during generation (default: same as --threads)");
+ printf(" -tbd N, --threads-batch-draft N\n");
+ printf(" number of threads to use during batch and prompt processing (default: same as --threads-draft)\n");
printf(" -p PROMPT, --prompt PROMPT\n");
printf(" prompt to start generation with (default: empty)\n");
printf(" -e, --escape process prompt escapes sequences (\\n, \\r, \\t, \\', \\\", \\\\)\n");
uint32_t seed = -1; // RNG seed
int32_t n_threads = get_num_physical_cores();
+ int32_t n_threads_draft = -1; // number of threads to use for draft generation (-1 = use n_threads)
int32_t n_threads_batch = -1; // number of threads to use for batch processing (-1 = use n_threads)
+ int32_t n_threads_batch_draft = -1; // number of threads to use for draft batch processing (-1 = use n_threads_draft)
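+ // example: with `-t 8 -td 4` and no `-tbd`, the draft model uses 4 threads
+ // for generation and, since n_threads_batch_draft stays -1, 4 for batching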
int32_t n_predict = -1; // new tokens to predict
int32_t n_ctx = 512; // context size
int32_t n_batch = 512; // batch size for prompt processing (must be >=32 to use BLAS)
// load the draft model
params.model = params.model_draft;
params.n_gpu_layers = params.n_gpu_layers_draft;
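+ // switch to the draft-specific thread counts; n_threads_batch is assigned
+ // unconditionally because -1 already means "fall back to n_threads", which
+ // at this point holds the draft generation thread count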
+ if (params.n_threads_draft > 0) {
+ params.n_threads = params.n_threads_draft;
+ }
+ params.n_threads_batch = params.n_threads_batch_draft;
std::tie(model_dft, ctx_dft) = llama_init_from_gpt_params(params);
{