]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
server : add server parameters for draft model cache type (#13782)
authoraa956 <redacted>
Thu, 19 Jun 2025 13:01:03 +0000 (16:01 +0300)
committerGitHub <redacted>
Thu, 19 Jun 2025 13:01:03 +0000 (16:01 +0300)
Co-authored-by: aa956 <redacted>
common/arg.cpp
common/common.h
tools/server/README.md
tools/server/server.cpp

index 231de227a9122815f2e28489dcb407095d6738bb..3dfaa71eff18806e396c8d3f8e3a9f1f4fb95a31 100644 (file)
@@ -3210,6 +3210,32 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.speculative.model.path = value;
         }
     ).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_MODEL_DRAFT"));
+    add_opt(common_arg(
+        {"-ctkd", "--cache-type-k-draft"}, "TYPE",
+        string_format(
+            "KV cache data type for K for the draft model\n"
+            "allowed values: %s\n"
+            "(default: %s)",
+            get_all_kv_cache_types().c_str(),
+            ggml_type_name(params.speculative.cache_type_k)
+        ),
+        [](common_params & params, const std::string & value) {
+            params.speculative.cache_type_k = kv_cache_type_from_str(value);
+        }
+    ).set_env("LLAMA_ARG_CACHE_TYPE_K_DRAFT"));
+    add_opt(common_arg(
+        {"-ctvd", "--cache-type-v-draft"}, "TYPE",
+        string_format(
+            "KV cache data type for V for the draft model\n"
+            "allowed values: %s\n"
+            "(default: %s)",
+            get_all_kv_cache_types().c_str(),
+            ggml_type_name(params.speculative.cache_type_v)
+        ),
+        [](common_params & params, const std::string & value) {
+            params.speculative.cache_type_v = kv_cache_type_from_str(value);
+        }
+    ).set_env("LLAMA_ARG_CACHE_TYPE_V_DRAFT"));
 
     add_opt(common_arg(
         {"-mv", "--model-vocoder"}, "FNAME",
index 00b6ca03a20b4b42634f8ef67e041f992694089b..5710c4e9735fdb524eb02ce37188c013b8f51c61 100644 (file)
@@ -199,6 +199,9 @@ struct common_params_speculative {
     float   p_split      =  0.1f; // speculative decoding split probability
     float   p_min        = 0.75f; // minimum speculative decoding probability (greedy)
 
+    ggml_type cache_type_k = GGML_TYPE_F16; // KV cache data type for the K
+    ggml_type cache_type_v = GGML_TYPE_F16; // KV cache data type for the V
+
     struct cpu_params cpuparams;
     struct cpu_params cpuparams_batch;
 
index 06533c172e530575cce579554b3df2d5c8e15ed6..43aa65d50ce3f2858e1ae85000ad0724e7f15f6c 100644 (file)
@@ -187,6 +187,8 @@ The project is under active development, and we are [looking for feedback and co
 | `-devd, --device-draft <dev1,dev2,..>` | comma-separated list of devices to use for offloading the draft model (none = don't offload)<br/>use --list-devices to see a list of available devices |
 | `-ngld, --gpu-layers-draft, --n-gpu-layers-draft N` | number of layers to store in VRAM for the draft model<br/>(env: LLAMA_ARG_N_GPU_LAYERS_DRAFT) |
 | `-md, --model-draft FNAME` | draft model for speculative decoding (default: unused)<br/>(env: LLAMA_ARG_MODEL_DRAFT) |
+| `-ctkd, --cache-type-k-draft TYPE` | KV cache data type for K for speculative decoding model<br/>allowed values: f32, f16, bf16, q8_0, q4_0, q4_1, iq4_nl, q5_0, q5_1<br/>(default: f16)<br/>(env: LLAMA_ARG_CACHE_TYPE_K_DRAFT) |
+| `-ctvd, --cache-type-v-draft TYPE` | KV cache data type for V for speculative decoding model<br/>allowed values: f32, f16, bf16, q8_0, q4_0, q4_1, iq4_nl, q5_0, q5_1<br/>(default: f16)<br/>(env: LLAMA_ARG_CACHE_TYPE_V_DRAFT) |
 | `-mv, --model-vocoder FNAME` | vocoder model for audio generation (default: unused) |
 | `--tts-use-guide-tokens` | Use guide tokens to improve TTS word recall |
 | `--embd-bge-small-en-default` | use default bge-small-en-v1.5 model (note: can download weights from the internet) |
index 721d09182845d4bc0bc2c893807958950ef46205..9d55b3338bcfe07863a00695173278e50dc1ede5 100644 (file)
@@ -1969,10 +1969,8 @@ struct server_context {
             params_dft.n_ctx        = params_base.speculative.n_ctx == 0 ? params_base.n_ctx / params_base.n_parallel : params_base.speculative.n_ctx;
             params_dft.n_gpu_layers = params_base.speculative.n_gpu_layers;
             params_dft.n_parallel   = 1;
-
-            // force F16 KV cache for the draft model for extra performance
-            params_dft.cache_type_k = GGML_TYPE_F16;
-            params_dft.cache_type_v = GGML_TYPE_F16;
+            params_dft.cache_type_k = params_base.speculative.cache_type_k;
+            params_dft.cache_type_v = params_base.speculative.cache_type_v;
 
             llama_init_dft = common_init_from_params(params_dft);