speculative : threading options (#4959)
author stduhpf <redacted>
Tue, 16 Jan 2024 11:04:32 +0000 (12:04 +0100)
committer GitHub <redacted>
Tue, 16 Jan 2024 11:04:32 +0000 (13:04 +0200)
* speculative: expose draft threading

* fix usage format

* accept -td and -tbd args

* speculative: revert default behavior when -td is unspecified

* fix trailing whitespace
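
Taken together, these changes let the speculative example drive the draft model with its own thread counts. A hypothetical invocation (binary and model paths are placeholders) might look like:

    ./speculative -m models/target.gguf -md models/draft.gguf -p "Hello" -t 8 -td 4 -tbd 4

Here the target model generates with 8 threads while the draft model uses 4 threads for both generation and batch processing; omitting -td keeps the previous behavior of the draft model sharing --threads.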

common/common.cpp
common/common.h
examples/speculative/speculative.cpp

diff --git a/common/common.cpp b/common/common.cpp
index c11006bcb91755a5518098a7dd029ceaadb0e874..2b0865fff0e62d27713fda3993b1c88b7d14e103 100644
@@ -167,6 +167,24 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
             if (params.n_threads_batch <= 0) {
                 params.n_threads_batch = std::thread::hardware_concurrency();
             }
+        } else if (arg == "-td" || arg == "--threads-draft") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.n_threads_draft = std::stoi(argv[i]);
+            if (params.n_threads_draft <= 0) {
+                params.n_threads_draft = std::thread::hardware_concurrency();
+            }
+        } else if (arg == "-tbd" || arg == "--threads-batch-draft") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.n_threads_batch_draft = std::stoi(argv[i]);
+            if (params.n_threads_batch_draft <= 0) {
+                params.n_threads_batch_draft = std::thread::hardware_concurrency();
+            }
         } else if (arg == "-p" || arg == "--prompt") {
             if (++i >= argc) {
                 invalid_param = true;
@@ -845,6 +863,10 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     printf("  -t N, --threads N     number of threads to use during generation (default: %d)\n", params.n_threads);
     printf("  -tb N, --threads-batch N\n");
     printf("                        number of threads to use during batch and prompt processing (default: same as --threads)\n");
+    printf("  -td N, --threads-draft N");
+    printf("                        number of threads to use during generation (default: same as --threads)");
+    printf("  -tbd N, --threads-batch-draft N\n");
+    printf("                        number of threads to use during batch and prompt processing (default: same as --threads-draft)\n");
     printf("  -p PROMPT, --prompt PROMPT\n");
     printf("                        prompt to start generation with (default: empty)\n");
     printf("  -e, --escape          process prompt escapes sequences (\\n, \\r, \\t, \\', \\\", \\\\)\n");
diff --git a/common/common.h b/common/common.h
index 096468243d88c1df2b557cc5c59d7a2309f91d22..1f43e6282f48dcf0032bb727b1fdf2bdea3946db 100644
@@ -46,7 +46,9 @@ struct gpt_params {
     uint32_t seed                           = -1;    // RNG seed
 
     int32_t n_threads                       = get_num_physical_cores();
+    int32_t n_threads_draft                 = -1;
     int32_t n_threads_batch                 = -1;    // number of threads to use for batch processing (-1 = use n_threads)
+    int32_t n_threads_batch_draft           = -1;
     int32_t n_predict                       = -1;    // new tokens to predict
     int32_t n_ctx                           = 512;   // context size
     int32_t n_batch                         = 512;   // batch size for prompt processing (must be >=32 to use BLAS)
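
Both new fields default to -1 and follow the same sentinel convention as n_threads_batch: a non-positive value means "inherit from the fallback". A minimal sketch of that rule, using a hypothetical helper that is not part of the commit:

    #include <cstdint>

    // Hypothetical helper: a non-positive request inherits the fallback value.
    static int32_t resolve_threads(int32_t requested, int32_t fallback) {
        return requested > 0 ? requested : fallback;
    }

    // e.g. resolve_threads(params.n_threads_draft, params.n_threads) yields
    // params.n_threads whenever -td was not given on the command line.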
diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp
index 20f1fb5bfcd99392efa2a8855825053ccdd071a1..7b3af01f339a9a29133e73be779fc343f2dbc4c7 100644
@@ -65,6 +65,10 @@ int main(int argc, char ** argv) {
     // load the draft model
     params.model = params.model_draft;
     params.n_gpu_layers = params.n_gpu_layers_draft;
+    if (params.n_threads_draft > 0) {
+        params.n_threads = params.n_threads_draft;
+    }
+    params.n_threads_batch = params.n_threads_batch_draft;
     std::tie(model_dft, ctx_dft) = llama_init_from_gpt_params(params);
 
     {
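
Note the asymmetry above: n_threads is only overridden when -td was given, while n_threads_batch is always replaced by n_threads_batch_draft. A short sketch (values are hypothetical, not from the commit) of the thread counts the draft model ends up with:

    #include <cstdio>

    int main() {
        // Parsed values; -1 means the flag was not given (hypothetical run: -t 8).
        int n_threads             = 8;   // -t 8
        int n_threads_draft       = -1;  // -td not given
        int n_threads_batch_draft = -1;  // -tbd not given

        // Same guard as the commit: only override when -td was specified.
        if (n_threads_draft > 0) {
            n_threads = n_threads_draft;
        }
        // Copied unconditionally; -1 later falls back to n_threads (the same
        // "-1 = use n_threads" convention noted in common.h above).
        int n_threads_batch = n_threads_batch_draft;

        printf("draft generation threads: %d, draft batch threads: %d\n",
               n_threads, n_threads_batch); // -> 8, -1
        return 0;
    }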