]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
speculative: add --n-gpu-layers-draft option (#3063)
authorFK <redacted>
Wed, 13 Sep 2023 06:50:46 +0000 (08:50 +0200)
committerGitHub <redacted>
Wed, 13 Sep 2023 06:50:46 +0000 (08:50 +0200)
common/common.cpp
common/common.h
examples/speculative/speculative.cpp

index 6e5d5b4d50757c1b71e6c5256749049edf90d786..afc9b8a55bfae60ebecee5bf1ffd536f750c966e 100644 (file)
@@ -374,6 +374,17 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
 #else
             fprintf(stderr, "warning: not compiled with GPU offload support, --n-gpu-layers option will be ignored\n");
             fprintf(stderr, "warning: see main README.md for information on enabling GPU BLAS support\n");
+#endif
+        } else if (arg == "--gpu-layers-draft" || arg == "-ngld" || arg == "--n-gpu-layers-draft") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+#ifdef LLAMA_SUPPORTS_GPU_OFFLOAD
+            params.n_gpu_layers_draft = std::stoi(argv[i]);
+#else
+            fprintf(stderr, "warning: not compiled with GPU offload support, --n-gpu-layers-draft option will be ignored\n");
+            fprintf(stderr, "warning: see main README.md for information on enabling GPU BLAS support\n");
 #endif
         } else if (arg == "--main-gpu" || arg == "-mg") {
             if (++i >= argc) {
@@ -664,6 +675,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
 #ifdef LLAMA_SUPPORTS_GPU_OFFLOAD
     printf("  -ngl N, --n-gpu-layers N\n");
     printf("                        number of layers to store in VRAM\n");
+    printf("  -ngld N, --n-gpu-layers-draft N\n");
+    printf("                        number of layers to store in VRAM for the draft model\n");
     printf("  -ts SPLIT --tensor-split SPLIT\n");
     printf("                        how to split tensors across multiple GPUs, comma-separated list of proportions, e.g. 3,1\n");
     printf("  -mg i, --main-gpu i   the GPU to use for scratch and small tensors\n");
index 012bf5e136f213394888bd6d2e0e4d22a2a08285..238635ae3065da6d4fa2de6d557f28b35ead94ba 100644 (file)
@@ -38,6 +38,7 @@ struct gpt_params {
     int32_t n_draft                         = 16;   // number of tokens to draft during speculative decoding
     int32_t n_chunks                        = -1;   // max number of chunks to process (-1 = unlimited)
     int32_t n_gpu_layers                    = -1;   // number of layers to store in VRAM (-1 - use default)
+    int32_t n_gpu_layers_draft              = -1;   // number of layers to store in VRAM for the draft model (-1 - use default)
     int32_t main_gpu                        = 0;    // the GPU that is used for scratch and small tensors
     float   tensor_split[LLAMA_MAX_DEVICES] = {0};  // how split tensors should be distributed across GPUs
     int32_t n_probs                         = 0;    // if greater than 0, output the probabilities of top n_probs tokens.
index 822d7b529f01d096108ac00509d836a5528d3347..2cd153f9a4e171a0e299225526e09511cdb415bf 100644 (file)
@@ -42,6 +42,7 @@ int main(int argc, char ** argv) {
 
     // load the draft model
     params.model = params.model_draft;
+    params.n_gpu_layers = params.n_gpu_layers_draft;
     std::tie(model_dft, ctx_dft) = llama_init_from_gpt_params(params);
 
     // tokenize the prompt