]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
speculative : change default p_accept to 0.5 + CLI args (#3919)
authorGeorgi Gerganov <redacted>
Fri, 3 Nov 2023 07:41:17 +0000 (09:41 +0200)
committerGeorgi Gerganov <redacted>
Fri, 3 Nov 2023 07:41:56 +0000 (09:41 +0200)
ggml-ci

common/common.cpp
common/common.h
examples/speculative/speculative.cpp

index e938dee165d9da13c15b3851ec032a3a79a29eed..20cc4a081b22253e4dd3ecfe126d0eb3c9213007 100644 (file)
@@ -403,6 +403,18 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
                 break;
             }
             params.n_sequences = std::stoi(argv[i]);
+        } else if (arg == "--p-accept" || arg == "-pa") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.p_accept = std::stof(argv[i]);
+        } else if (arg == "--p-split" || arg == "-ps") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.p_split = std::stof(argv[i]);
         } else if (arg == "-m" || arg == "--model") {
             if (++i >= argc) {
                 invalid_param = true;
@@ -778,6 +790,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     printf("  --chunks N            max number of chunks to process (default: %d, -1 = all)\n", params.n_chunks);
     printf("  -np N, --parallel N   number of parallel sequences to decode (default: %d)\n", params.n_parallel);
     printf("  -ns N, --sequences N  number of sequences to decode (default: %d)\n", params.n_sequences);
+    printf("  -pa N, --p-accept N   speculative decoding accept probability (default: %.1f)\n", (double)params.p_accept);
+    printf("  -ps N, --p-split N    speculative decoding split probability (default: %.1f)\n", (double)params.p_split);
     printf("  -cb, --cont-batching  enable continuous batching (a.k.a dynamic batching) (default: disabled)\n");
     printf("  --mmproj MMPROJ_FILE  path to a multimodal projector file for LLaVA. see examples/llava/README.md\n");
     printf("  --image IMAGE_FILE    path to an image file. use with multimodal models\n");
index 9ad62563302872a8562835d81234a6e54300578a..dd6b002eb94ba27f9bbd15c4d5172ee687c8bb8f 100644 (file)
@@ -44,6 +44,7 @@ int32_t get_num_physical_cores();
 
 struct gpt_params {
     uint32_t seed                           = -1;    // RNG seed
+
     int32_t n_threads                       = get_num_physical_cores();
     int32_t n_threads_batch                 = -1;    // number of threads to use for batch processing (-1 = use n_threads)
     int32_t n_predict                       = -1;    // new tokens to predict
@@ -54,6 +55,8 @@ struct gpt_params {
     int32_t n_chunks                        = -1;    // max number of chunks to process (-1 = unlimited)
     int32_t n_parallel                      = 1;     // number of parallel sequences to decode
     int32_t n_sequences                     = 1;     // number of sequences to decode
+    float   p_accept                        = 0.5f;  // speculative decoding accept probability
+    float   p_split                         = 0.1f;  // speculative decoding split probability
     int32_t n_gpu_layers                    = -1;    // number of layers to store in VRAM (-1 - use default)
     int32_t n_gpu_layers_draft              = -1;    // number of layers to store in VRAM for the draft model (-1 - use default)
     int32_t main_gpu                        = 0;     // the GPU that is used for scratch and small tensors
@@ -66,7 +69,8 @@ struct gpt_params {
     float   yarn_beta_fast                  = 32.0f; // YaRN low correction dim
     float   yarn_beta_slow                  = 1.0f;  // YaRN high correction dim
     int32_t yarn_orig_ctx                   = 0;     // YaRN original context length
-    int8_t  rope_scaling_type               = LLAMA_ROPE_SCALING_UNSPECIFIED;
+    int8_t  rope_scaling_type               = LLAMA_ROPE_SCALING_UNSPECIFIED; // TODO: better to be int32_t for alignment
+                                                                              //       pinging @cebtenzzre
 
     // // sampling parameters
     struct llama_sampling_params sparams;
@@ -90,7 +94,7 @@ struct gpt_params {
     int  ppl_output_type   = 0;     // = 0 -> ppl output is as usual, = 1 -> ppl output is num_tokens, ppl, one per line
                                     //                                       (which is more convenient to use for plotting)
                                     //
-    bool hellaswag         = false; // compute HellaSwag score over random tasks from datafile supplied in prompt
+    bool   hellaswag       = false; // compute HellaSwag score over random tasks from datafile supplied in prompt
     size_t hellaswag_tasks = 400;   // number of tasks to use when computing the HellaSwag score
 
     bool mul_mat_q         = true;  // if true, use mul_mat_q kernels instead of cuBLAS
index 798684f66678e2c48fef7c1f7708f6aebe9bf053..3a8e278110c20ec7ce509b5ef8fa5139caceb052 100644 (file)
@@ -37,9 +37,11 @@ int main(int argc, char ** argv) {
     // max number of parallel drafting sequences (i.e. tree branches)
     const int n_seq_dft = params.n_parallel;
 
-    // TODO: make this configurable
-    const float p_accept = 0.80f;
-    const float p_split  = 0.10f;
+    // probability threshold for accepting a token from the draft model
+    const float p_accept = params.p_accept;
+
+    // probability threshold for splitting a draft branch (only for n_seq_dft > 1)
+    const float p_split  = params.p_split;
 
 #ifndef LOG_DISABLE_LOGS
     log_set_target(log_filename_generator("speculative", "log"));