]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
server : enable /slots by default and make it secure (#15630)
authorGeorgi Gerganov <redacted>
Sun, 31 Aug 2025 17:11:58 +0000 (20:11 +0300)
committerGitHub <redacted>
Sun, 31 Aug 2025 17:11:58 +0000 (20:11 +0300)
* server : enable /slots by default and make it secure

ggml-ci

* server : fix tests to pass `--no-slots` when necessary

* server : extend /props with info about enabled endpoints

common/arg.cpp
common/common.h
tools/server/README.md
tools/server/server.cpp
tools/server/tests/utils.py

index 72c69c39a0fe199bdedc89fbe2ef184ae6a83e45..4fa214d3d28569cf9d74e6dc00b38ec4bb671e86 100644 (file)
@@ -2962,13 +2962,6 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.endpoint_metrics = true;
         }
     ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_ENDPOINT_METRICS"));
-    add_opt(common_arg(
-        {"--slots"},
-        string_format("enable slots monitoring endpoint (default: %s)", params.endpoint_slots ? "enabled" : "disabled"),
-        [](common_params & params) {
-            params.endpoint_slots = true;
-        }
-    ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_ENDPOINT_SLOTS"));
     add_opt(common_arg(
         {"--props"},
         string_format("enable changing global properties via POST /props (default: %s)", params.endpoint_props ? "enabled" : "disabled"),
@@ -2976,6 +2969,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.endpoint_props = true;
         }
     ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_ENDPOINT_PROPS"));
+    add_opt(common_arg(
+        {"--slots"},
+        string_format("enable slots monitoring endpoint (default: %s)", params.endpoint_slots ? "enabled" : "disabled"),
+        [](common_params & params) {
+            params.endpoint_slots = true;
+        }
+    ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_ENDPOINT_SLOTS"));
     add_opt(common_arg(
         {"--no-slots"},
         "disables slots monitoring endpoint",
index 02ca093bdf8b7c7269bad7524fda147f3f12fa42..85b3b879d45362962d8f22b0161d85a3c0f97b7b 100644 (file)
@@ -444,7 +444,7 @@ struct common_params {
 
     // "advanced" endpoints are disabled by default for better security
     bool webui            = true;
-    bool endpoint_slots   = false;
+    bool endpoint_slots   = true;
     bool endpoint_props   = false; // only control POST requests, not GET
     bool endpoint_metrics = false;
 
index b7285b231992ea7ec9d65376c0064cb6052368d2..b0527f3cbea285dacae84ed3dd71827d9c07e2b1 100644 (file)
@@ -37,7 +37,7 @@ The project is under active development, and we are [looking for feedback and co
 | `-C, --cpu-mask M` | CPU affinity mask: arbitrarily long hex. Complements cpu-range (default: "") |
 | `-Cr, --cpu-range lo-hi` | range of CPUs for affinity. Complements --cpu-mask |
 | `--cpu-strict <0\|1>` | use strict CPU placement (default: 0)<br/> |
-| `--prio N` | set process/thread priority : 0-normal, 1-medium, 2-high, 3-realtime (default: 0)<br/> |
+| `--prio N` | set process/thread priority : low(-1), normal(0), medium(1), high(2), realtime(3) (default: 0)<br/> |
 | `--poll <0...100>` | use polling level to wait for work (0 - no polling, default: 50)<br/> |
 | `-Cb, --cpu-mask-batch M` | CPU affinity mask: arbitrarily long hex. Complements cpu-range-batch (default: same as --cpu-mask) |
 | `-Crb, --cpu-range-batch lo-hi` | ranges of CPUs for affinity. Complements --cpu-mask-batch |
@@ -49,6 +49,8 @@ The project is under active development, and we are [looking for feedback and co
 | `-b, --batch-size N` | logical maximum batch size (default: 2048)<br/>(env: LLAMA_ARG_BATCH) |
 | `-ub, --ubatch-size N` | physical maximum batch size (default: 512)<br/>(env: LLAMA_ARG_UBATCH) |
 | `--keep N` | number of tokens to keep from the initial prompt (default: 0, -1 = all) |
+| `--swa-full` | use full-size SWA cache (default: false)<br/>[(more info)](https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055)<br/>(env: LLAMA_ARG_SWA_FULL) |
+| `--kv-unified, -kvu` | use single unified KV buffer for the KV cache of all sequences (default: false)<br/>[(more info)](https://github.com/ggml-org/llama.cpp/pull/14363)<br/>(env: LLAMA_ARG_KV_SPLIT) |
 | `-fa, --flash-attn` | enable Flash Attention (default: disabled)<br/>(env: LLAMA_ARG_FLASH_ATTN) |
 | `--no-perf` | disable internal libllama performance timings (default: false)<br/>(env: LLAMA_ARG_NO_PERF) |
 | `-e, --escape` | process escapes sequences (\n, \r, \t, \', \", \\) (default: true) |
@@ -63,6 +65,7 @@ The project is under active development, and we are [looking for feedback and co
 | `--yarn-beta-slow N` | YaRN: high correction dim or alpha (default: 1.0)<br/>(env: LLAMA_ARG_YARN_BETA_SLOW) |
 | `--yarn-beta-fast N` | YaRN: low correction dim or beta (default: 32.0)<br/>(env: LLAMA_ARG_YARN_BETA_FAST) |
 | `-nkvo, --no-kv-offload` | disable KV offload<br/>(env: LLAMA_ARG_NO_KV_OFFLOAD) |
+| `-nr, --no-repack` | disable weight repacking<br/>(env: LLAMA_ARG_NO_REPACK) |
 | `-ctk, --cache-type-k TYPE` | KV cache data type for K<br/>allowed values: f32, f16, bf16, q8_0, q4_0, q4_1, iq4_nl, q5_0, q5_1<br/>(default: f16)<br/>(env: LLAMA_ARG_CACHE_TYPE_K) |
 | `-ctv, --cache-type-v TYPE` | KV cache data type for V<br/>allowed values: f32, f16, bf16, q8_0, q4_0, q4_1, iq4_nl, q5_0, q5_1<br/>(default: f16)<br/>(env: LLAMA_ARG_CACHE_TYPE_V) |
 | `-dt, --defrag-thold N` | KV cache defragmentation threshold (DEPRECATED)<br/>(env: LLAMA_ARG_DEFRAG_THOLD) |
@@ -73,12 +76,15 @@ The project is under active development, and we are [looking for feedback and co
 | `-dev, --device <dev1,dev2,..>` | comma-separated list of devices to use for offloading (none = don't offload)<br/>use --list-devices to see a list of available devices<br/>(env: LLAMA_ARG_DEVICE) |
 | `--list-devices` | print list of available devices and exit |
 | `--override-tensor, -ot <tensor name pattern>=<buffer type>,...` | override tensor buffer type |
+| `--cpu-moe, -cmoe` | keep all Mixture of Experts (MoE) weights in the CPU<br/>(env: LLAMA_ARG_CPU_MOE) |
+| `--n-cpu-moe, -ncmoe N` | keep the Mixture of Experts (MoE) weights of the first N layers in the CPU<br/>(env: LLAMA_ARG_N_CPU_MOE) |
 | `-ngl, --gpu-layers, --n-gpu-layers N` | number of layers to store in VRAM<br/>(env: LLAMA_ARG_N_GPU_LAYERS) |
 | `-sm, --split-mode {none,layer,row}` | how to split the model across multiple GPUs, one of:<br/>- none: use one GPU only<br/>- layer (default): split layers and KV across GPUs<br/>- row: split rows across GPUs<br/>(env: LLAMA_ARG_SPLIT_MODE) |
 | `-ts, --tensor-split N0,N1,N2,...` | fraction of the model to offload to each GPU, comma-separated list of proportions, e.g. 3,1<br/>(env: LLAMA_ARG_TENSOR_SPLIT) |
 | `-mg, --main-gpu INDEX` | the GPU to use for the model (with split-mode = none), or for intermediate results and KV (with split-mode = row) (default: 0)<br/>(env: LLAMA_ARG_MAIN_GPU) |
 | `--check-tensors` | check model tensor data for invalid values (default: false) |
 | `--override-kv KEY=TYPE:VALUE` | advanced option to override model metadata by key. may be specified multiple times.<br/>types: int, float, bool, str. example: --override-kv tokenizer.ggml.add_bos_token=bool:false |
+| `--no-op-offload` | disable offloading host tensor operations to device (default: false) |
 | `--lora FNAME` | path to LoRA adapter (can be repeated to use multiple adapters) |
 | `--lora-scaled FNAME SCALE` | path to LoRA adapter with user defined scaling (can be repeated to use multiple adapters) |
 | `--control-vector FNAME` | add a control vector<br/>note: this argument can be repeated to add multiple control vectors |
@@ -96,9 +102,12 @@ The project is under active development, and we are [looking for feedback and co
 | `--log-file FNAME` | Log to file |
 | `--log-colors` | Enable colored logging<br/>(env: LLAMA_LOG_COLORS) |
 | `-v, --verbose, --log-verbose` | Set verbosity level to infinity (i.e. log all messages, useful for debugging) |
+| `--offline` | Offline mode: forces use of cache, prevents network access<br/>(env: LLAMA_OFFLINE) |
 | `-lv, --verbosity, --log-verbosity N` | Set the verbosity threshold. Messages with a higher verbosity will be ignored.<br/>(env: LLAMA_LOG_VERBOSITY) |
 | `--log-prefix` | Enable prefix in log messages<br/>(env: LLAMA_LOG_PREFIX) |
 | `--log-timestamps` | Enable timestamps in log messages<br/>(env: LLAMA_LOG_TIMESTAMPS) |
+| `-ctkd, --cache-type-k-draft TYPE` | KV cache data type for K for the draft model<br/>allowed values: f32, f16, bf16, q8_0, q4_0, q4_1, iq4_nl, q5_0, q5_1<br/>(default: f16)<br/>(env: LLAMA_ARG_CACHE_TYPE_K_DRAFT) |
+| `-ctvd, --cache-type-v-draft TYPE` | KV cache data type for V for the draft model<br/>allowed values: f32, f16, bf16, q8_0, q4_0, q4_1, iq4_nl, q5_0, q5_1<br/>(default: f16)<br/>(env: LLAMA_ARG_CACHE_TYPE_V_DRAFT) |
 
 
 **Sampling params**
@@ -113,6 +122,7 @@ The project is under active development, and we are [looking for feedback and co
 | `--top-k N` | top-k sampling (default: 40, 0 = disabled) |
 | `--top-p N` | top-p sampling (default: 0.9, 1.0 = disabled) |
 | `--min-p N` | min-p sampling (default: 0.1, 0.0 = disabled) |
+| `--top-nsigma N` | top-n-sigma sampling (default: -1.0, -1.0 = disabled) |
 | `--xtc-probability N` | xtc probability (default: 0.0, 0.0 = disabled) |
 | `--xtc-threshold N` | xtc threshold (default: 0.1, 1.0 = disabled) |
 | `--typical N` | locally typical sampling, parameter p (default: 1.0, 1.0 = disabled) |
@@ -141,7 +151,10 @@ The project is under active development, and we are [looking for feedback and co
 
 | Argument | Explanation |
 | -------- | ----------- |
-| `--no-context-shift` | disables context shift on infinite text generation (default: disabled)<br/>(env: LLAMA_ARG_NO_CONTEXT_SHIFT) |
+| `--swa-checkpoints N` | max number of SWA checkpoints per slot to create (default: 3)<br/>[(more info)](https://github.com/ggml-org/llama.cpp/pull/15293)<br/>(env: LLAMA_ARG_SWA_CHECKPOINTS) |
+| `--no-context-shift` | disables context shift on infinite text generation (default: enabled)<br/>(env: LLAMA_ARG_NO_CONTEXT_SHIFT) |
+| `--context-shift` | enables context shift on infinite text generation (default: disabled)<br/>(env: LLAMA_ARG_CONTEXT_SHIFT) |
+| `-r, --reverse-prompt PROMPT` | halt generation at PROMPT, return control in interactive mode<br/> |
 | `-sp, --special` | special tokens output enabled (default: false) |
 | `--no-warmup` | skip warming up the model with an empty run |
 | `--spm-infill` | use Suffix/Prefix/Middle pattern for infill (instead of Prefix/Suffix/Middle) as some models prefer this. (default: disabled) |
@@ -152,10 +165,14 @@ The project is under active development, and we are [looking for feedback and co
 | `--mmproj-url URL` | URL to a multimodal projector file. see tools/mtmd/README.md<br/>(env: LLAMA_ARG_MMPROJ_URL) |
 | `--no-mmproj` | explicitly disable multimodal projector, useful when using -hf<br/>(env: LLAMA_ARG_NO_MMPROJ) |
 | `--no-mmproj-offload` | do not offload multimodal projector to GPU<br/>(env: LLAMA_ARG_NO_MMPROJ_OFFLOAD) |
+| `--override-tensor-draft, -otd <tensor name pattern>=<buffer type>,...` | override tensor buffer type for draft model |
+| `--cpu-moe-draft, -cmoed` | keep all Mixture of Experts (MoE) weights in the CPU for the draft model<br/>(env: LLAMA_ARG_CPU_MOE_DRAFT) |
+| `--n-cpu-moe-draft, -ncmoed N` | keep the Mixture of Experts (MoE) weights of the first N layers in the CPU for the draft model<br/>(env: LLAMA_ARG_N_CPU_MOE_DRAFT) |
 | `-a, --alias STRING` | set alias for model name (to be used by REST API)<br/>(env: LLAMA_ARG_ALIAS) |
 | `--host HOST` | ip address to listen, or bind to an UNIX socket if the address ends with .sock (default: 127.0.0.1)<br/>(env: LLAMA_ARG_HOST) |
 | `--port PORT` | port to listen (default: 8080)<br/>(env: LLAMA_ARG_PORT) |
 | `--path PATH` | path to serve static files from (default: )<br/>(env: LLAMA_ARG_STATIC_PATH) |
+| `--api-prefix PREFIX` | prefix path the server serves from, without the trailing slash (default: )<br/>(env: LLAMA_ARG_API_PREFIX) |
 | `--no-webui` | Disable the Web UI (default: enabled)<br/>(env: LLAMA_ARG_NO_WEBUI) |
 | `--embedding, --embeddings` | restrict to only support embedding use case; use only with dedicated embedding models (default: disabled)<br/>(env: LLAMA_ARG_EMBEDDINGS) |
 | `--reranking, --rerank` | enable reranking endpoint on server (default: disabled)<br/>(env: LLAMA_ARG_RERANKING) |
@@ -163,23 +180,25 @@ The project is under active development, and we are [looking for feedback and co
 | `--api-key-file FNAME` | path to file containing API keys (default: none) |
 | `--ssl-key-file FNAME` | path to file a PEM-encoded SSL private key<br/>(env: LLAMA_ARG_SSL_KEY_FILE) |
 | `--ssl-cert-file FNAME` | path to file a PEM-encoded SSL certificate<br/>(env: LLAMA_ARG_SSL_CERT_FILE) |
-| `--chat-template-kwargs STRING` | JSON object containing additional params for the json template parser. Example: `--chat_template_kwargs "{\"enable_thinking\":false}`"<br/>(env: LLAMA_CHAT_TEMPLATE_KWARGS) |
+| `--chat-template-kwargs STRING` | sets additional params for the json template parser<br/>(env: LLAMA_CHAT_TEMPLATE_KWARGS) |
 | `-to, --timeout N` | server read/write timeout in seconds (default: 600)<br/>(env: LLAMA_ARG_TIMEOUT) |
 | `--threads-http N` | number of threads used to process HTTP requests (default: -1)<br/>(env: LLAMA_ARG_THREADS_HTTP) |
 | `--cache-reuse N` | min chunk size to attempt reusing from the cache via KV shifting (default: 0)<br/>[(card)](https://ggml.ai/f0.png)<br/>(env: LLAMA_ARG_CACHE_REUSE) |
 | `--metrics` | enable prometheus compatible metrics endpoint (default: disabled)<br/>(env: LLAMA_ARG_ENDPOINT_METRICS) |
-| `--slots` | enable slots monitoring endpoint (default: disabled)<br/>(env: LLAMA_ARG_ENDPOINT_SLOTS) |
 | `--props` | enable changing global properties via POST /props (default: disabled)<br/>(env: LLAMA_ARG_ENDPOINT_PROPS) |
+| `--slots` | enable slots monitoring endpoint (default: enabled)<br/>(env: LLAMA_ARG_ENDPOINT_SLOTS) |
 | `--no-slots` | disables slots monitoring endpoint<br/>(env: LLAMA_ARG_NO_ENDPOINT_SLOTS) |
 | `--slot-save-path PATH` | path to save slot kv cache (default: disabled) |
 | `--jinja` | use jinja template for chat (default: disabled)<br/>(env: LLAMA_ARG_JINJA) |
-| `--reasoning-format FORMAT` | controls whether thought tags are allowed and/or extracted from the response, and in which format they're returned; one of:<br/>- none: leaves thoughts unparsed in `message.content`<br/>- deepseek: puts thoughts in `message.reasoning_content` (except in streaming mode, which behaves as `none`)<br/>(default: deepseek)<br/>(env: LLAMA_ARG_THINK) |
+| `--reasoning-format FORMAT` | controls whether thought tags are allowed and/or extracted from the response, and in which format they're returned; one of:<br/>- none: leaves thoughts unparsed in `message.content`<br/>- deepseek: puts thoughts in `message.reasoning_content` (except in streaming mode, which behaves as `none`)<br/>(default: auto)<br/>(env: LLAMA_ARG_THINK) |
 | `--reasoning-budget N` | controls the amount of thinking allowed; currently only one of: -1 for unrestricted thinking budget, or 0 to disable thinking (default: -1)<br/>(env: LLAMA_ARG_THINK_BUDGET) |
-| `--chat-template JINJA_TEMPLATE` | set custom jinja chat template (default: template taken from model's metadata)<br/>if suffix/prefix are specified, template will be disabled<br/>only commonly used templates are accepted (unless --jinja is set before this flag):<br/>list of built-in templates:<br/>bailing, chatglm3, chatglm4, chatml, command-r, deepseek, deepseek2, deepseek3, exaone3, falcon3, gemma, gigachat, glmedge, granite, llama2, llama2-sys, llama2-sys-bos, llama2-sys-strip, llama3, llama4, megrez, minicpm, mistral-v1, mistral-v3, mistral-v3-tekken, mistral-v7, mistral-v7-tekken, monarch, openchat, orion, phi3, phi4, rwkv-world, smolvlm, vicuna, vicuna-orca, yandex, zephyr<br/>(env: LLAMA_ARG_CHAT_TEMPLATE) |
-| `--chat-template-file JINJA_TEMPLATE_FILE` | set custom jinja chat template file (default: template taken from model's metadata)<br/>if suffix/prefix are specified, template will be disabled<br/>only commonly used templates are accepted (unless --jinja is set before this flag):<br/>list of built-in templates:<br/>bailing, chatglm3, chatglm4, chatml, command-r, deepseek, deepseek2, deepseek3, exaone3, falcon3, gemma, gigachat, glmedge, granite, llama2, llama2-sys, llama2-sys-bos, llama2-sys-strip, llama3, llama4, megrez, minicpm, mistral-v1, mistral-v3, mistral-v3-tekken, mistral-v7, mistral-v7-tekken, monarch, openchat, orion, phi3, phi4, rwkv-world, smolvlm, vicuna, vicuna-orca, yandex, zephyr<br/>(env: LLAMA_ARG_CHAT_TEMPLATE_FILE) |
-| `--no-prefill-assistant` | whether to prefill the assistant's response if the last message is an assistant message (default: prefill enabled)<br/>when this flag is set, if the last message is an assistant message then it will be treated as a full message and not prefilled<br/>(env: LLAMA_ARG_NO_PREFILL_ASSISTANT) |
+| `--chat-template JINJA_TEMPLATE` | set custom jinja chat template (default: template taken from model's metadata)<br/>if suffix/prefix are specified, template will be disabled<br/>only commonly used templates are accepted (unless --jinja is set before this flag):<br/>list of built-in templates:<br/>bailing, chatglm3, chatglm4, chatml, command-r, deepseek, deepseek2, deepseek3, exaone3, exaone4, falcon3, gemma, gigachat, glmedge, gpt-oss, granite, hunyuan-dense, hunyuan-moe, kimi-k2, llama2, llama2-sys, llama2-sys-bos, llama2-sys-strip, llama3, llama4, megrez, minicpm, mistral-v1, mistral-v3, mistral-v3-tekken, mistral-v7, mistral-v7-tekken, monarch, openchat, orion, phi3, phi4, rwkv-world, seed_oss, smolvlm, vicuna, vicuna-orca, yandex, zephyr<br/>(env: LLAMA_ARG_CHAT_TEMPLATE) |
+| `--chat-template-file JINJA_TEMPLATE_FILE` | set custom jinja chat template file (default: template taken from model's metadata)<br/>if suffix/prefix are specified, template will be disabled<br/>only commonly used templates are accepted (unless --jinja is set before this flag):<br/>list of built-in templates:<br/>bailing, chatglm3, chatglm4, chatml, command-r, deepseek, deepseek2, deepseek3, exaone3, exaone4, falcon3, gemma, gigachat, glmedge, gpt-oss, granite, hunyuan-dense, hunyuan-moe, kimi-k2, llama2, llama2-sys, llama2-sys-bos, llama2-sys-strip, llama3, llama4, megrez, minicpm, mistral-v1, mistral-v3, mistral-v3-tekken, mistral-v7, mistral-v7-tekken, monarch, openchat, orion, phi3, phi4, rwkv-world, seed_oss, smolvlm, vicuna, vicuna-orca, yandex, zephyr<br/>(env: LLAMA_ARG_CHAT_TEMPLATE_FILE) |
+| `--no-prefill-assistant` | whether to prefill the assistant's response if the last message is an assistant message (default: prefill enabled)<br/>when this flag is set, if the last message is an assistant message then it will be treated as a full message and not prefilled<br/><br/>(env: LLAMA_ARG_NO_PREFILL_ASSISTANT) |
 | `-sps, --slot-prompt-similarity SIMILARITY` | how much the prompt of a request must match the prompt of a slot in order to use that slot (default: 0.50, 0.0 = disabled)<br/> |
 | `--lora-init-without-apply` | load LoRA adapters without applying them (apply later via POST /lora-adapters) (default: disabled) |
+| `-td, --threads-draft N` | number of threads to use during generation (default: same as --threads) |
+| `-tbd, --threads-batch-draft N` | number of threads to use during batch and prompt processing (default: same as --threads-draft) |
 | `--draft-max, --draft, --draft-n N` | number of tokens to draft for speculative decoding (default: 16)<br/>(env: LLAMA_ARG_DRAFT_MAX) |
 | `--draft-min, --draft-n-min N` | minimum number of draft tokens to use for speculative decoding (default: 0)<br/>(env: LLAMA_ARG_DRAFT_MIN) |
 | `--draft-p-min P` | minimum speculative decoding probability (greedy) (default: 0.8)<br/>(env: LLAMA_ARG_DRAFT_P_MIN) |
@@ -187,8 +206,7 @@ The project is under active development, and we are [looking for feedback and co
 | `-devd, --device-draft <dev1,dev2,..>` | comma-separated list of devices to use for offloading the draft model (none = don't offload)<br/>use --list-devices to see a list of available devices |
 | `-ngld, --gpu-layers-draft, --n-gpu-layers-draft N` | number of layers to store in VRAM for the draft model<br/>(env: LLAMA_ARG_N_GPU_LAYERS_DRAFT) |
 | `-md, --model-draft FNAME` | draft model for speculative decoding (default: unused)<br/>(env: LLAMA_ARG_MODEL_DRAFT) |
-| `-ctkd, --cache-type-k-draft TYPE` | KV cache data type for K for speculative decoding model<br/>allowed values: f32, f16, bf16, q8_0, q4_0, q4_1, iq4_nl, q5_0, q5_1<br/>(default: f16)<br/>(env: LLAMA_ARG_CACHE_TYPE_K_DRAFT) |
-| `-ctvd, --cache-type-v-draft TYPE` | KV cache data type for V for speculative decoding model<br/>allowed values: f32, f16, bf16, q8_0, q4_0, q4_1, iq4_nl, q5_0, q5_1<br/>(default: f16)<br/>(env: LLAMA_ARG_CACHE_TYPE_V_DRAFT) |
+| `--spec-replace TARGET DRAFT` | translate the string in TARGET into DRAFT if the draft model and main model are not compatible |
 | `-mv, --model-vocoder FNAME` | vocoder model for audio generation (default: unused) |
 | `--tts-use-guide-tokens` | Use guide tokens to improve TTS word recall |
 | `--embd-bge-small-en-default` | use default bge-small-en-v1.5 model (note: can download weights from the internet) |
@@ -199,6 +217,7 @@ The project is under active development, and we are [looking for feedback and co
 | `--fim-qwen-7b-default` | use default Qwen 2.5 Coder 7B (note: can download weights from the internet) |
 | `--fim-qwen-7b-spec` | use Qwen 2.5 Coder 7B + 0.5B draft for speculative decoding (note: can download weights from the internet) |
 | `--fim-qwen-14b-spec` | use Qwen 2.5 Coder 14B + 0.5B draft for speculative decoding (note: can download weights from the internet) |
+| `--fim-qwen-30b-default` | use default Qwen 3 Coder 30B A3B Instruct (note: can download weights from the internet) |
 
 
 Note: If both command line argument and environment variable are both set for the same param, the argument will take precedence over env var.
@@ -865,25 +884,23 @@ Same as the `/v1/embeddings` endpoint.
 
 ### GET `/slots`: Returns the current slots processing state
 
-> [!WARNING]
-> This endpoint is intended for debugging and may be modified in future versions. For security reasons, we strongly advise against enabling it in production environments.
-
-This endpoint is disabled by default and can be enabled with `--slots`
+This endpoint is enabled by default and can be disabled with `--no-slots`. It can be used to query various per-slot metrics, such as speed, processed tokens, sampling parameters, etc.
 
 If query param `?fail_on_no_slot=1` is set, this endpoint will respond with status code 503 if there is no available slots.
 
 **Response format**
 
-Example:
+<details>
+<summary>Example with 2 slots</summary>
 
 ```json
 [
   {
     "id": 0,
-    "id_task": -1,
-    "n_ctx": 1024,
+    "id_task": 135,
+    "n_ctx": 65536,
     "speculative": false,
-    "is_processing": false,
+    "is_processing": true,
     "params": {
       "n_predict": -1,
       "seed": 4294967295,
@@ -893,6 +910,7 @@ Example:
       "top_k": 40,
       "top_p": 0.949999988079071,
       "min_p": 0.05000000074505806,
+      "top_n_sigma": -1.0,
       "xtc_probability": 0.0,
       "xtc_threshold": 0.10000000149011612,
       "typical_p": 1.0,
@@ -903,17 +921,10 @@ Example:
       "dry_multiplier": 0.0,
       "dry_base": 1.75,
       "dry_allowed_length": 2,
-      "dry_penalty_last_n": -1,
-      "dry_sequence_breakers": [
-        "\n",
-        ":",
-        "\"",
-        "*"
-      ],
+      "dry_penalty_last_n": 131072,
       "mirostat": 0,
       "mirostat_tau": 5.0,
       "mirostat_eta": 0.10000000149011612,
-      "stop": [],
       "max_tokens": -1,
       "n_keep": 0,
       "n_discard": 0,
@@ -921,8 +932,12 @@ Example:
       "stream": true,
       "n_probs": 0,
       "min_keep": 0,
-      "grammar": "",
+      "chat_format": "GPT-OSS",
+      "reasoning_format": "none",
+      "reasoning_in_content": false,
+      "thinking_forced_open": false,
       "samplers": [
+        "penalties",
         "dry",
         "top_k",
         "typ_p",
@@ -932,22 +947,89 @@ Example:
         "temperature"
       ],
       "speculative.n_max": 16,
-      "speculative.n_min": 5,
-      "speculative.p_min": 0.8999999761581421,
-      "timings_per_token": false
+      "speculative.n_min": 0,
+      "speculative.p_min": 0.75,
+      "timings_per_token": false,
+      "post_sampling_probs": false,
+      "lora": []
     },
-    "prompt": "",
     "next_token": {
       "has_next_token": true,
       "has_new_line": false,
       "n_remain": -1,
-      "n_decoded": 0,
-      "stopping_word": ""
+      "n_decoded": 0
+    }
+  },
+  {
+    "id": 1,
+    "id_task": 0,
+    "n_ctx": 65536,
+    "speculative": false,
+    "is_processing": true,
+    "params": {
+      "n_predict": -1,
+      "seed": 4294967295,
+      "temperature": 0.800000011920929,
+      "dynatemp_range": 0.0,
+      "dynatemp_exponent": 1.0,
+      "top_k": 40,
+      "top_p": 0.949999988079071,
+      "min_p": 0.05000000074505806,
+      "top_n_sigma": -1.0,
+      "xtc_probability": 0.0,
+      "xtc_threshold": 0.10000000149011612,
+      "typical_p": 1.0,
+      "repeat_last_n": 64,
+      "repeat_penalty": 1.0,
+      "presence_penalty": 0.0,
+      "frequency_penalty": 0.0,
+      "dry_multiplier": 0.0,
+      "dry_base": 1.75,
+      "dry_allowed_length": 2,
+      "dry_penalty_last_n": 131072,
+      "mirostat": 0,
+      "mirostat_tau": 5.0,
+      "mirostat_eta": 0.10000000149011612,
+      "max_tokens": -1,
+      "n_keep": 0,
+      "n_discard": 0,
+      "ignore_eos": false,
+      "stream": true,
+      "n_probs": 0,
+      "min_keep": 0,
+      "chat_format": "GPT-OSS",
+      "reasoning_format": "none",
+      "reasoning_in_content": false,
+      "thinking_forced_open": false,
+      "samplers": [
+        "penalties",
+        "dry",
+        "top_k",
+        "typ_p",
+        "top_p",
+        "min_p",
+        "xtc",
+        "temperature"
+      ],
+      "speculative.n_max": 16,
+      "speculative.n_min": 0,
+      "speculative.p_min": 0.75,
+      "timings_per_token": false,
+      "post_sampling_probs": false,
+      "lora": []
+    },
+    "next_token": {
+      "has_next_token": true,
+      "has_new_line": true,
+      "n_remain": -1,
+      "n_decoded": 136
     }
   }
 ]
 ```
 
+</details>
+
 ### GET `/metrics`: Prometheus compatible metrics exporter
 
 This endpoint is only accessible if `--metrics` is set.
index 6aa319d2f112157a0e99de5ad999ea87874b9219..aebd886ea2a946ea8482cb9fc0066d4d7381ac9a 100644 (file)
@@ -141,7 +141,7 @@ struct slot_params {
     // Embeddings
     int32_t embd_normalize = 2; // (-1=none, 0=max absolute int16, 1=taxicab, 2=Euclidean/L2, >2=p-norm)
 
-    json to_json() const {
+    json to_json(bool only_metrics = false) const {
         std::vector<std::string> samplers;
         samplers.reserve(sampling.samplers.size());
         for (const auto & sampler : sampling.samplers) {
@@ -153,9 +153,55 @@ struct slot_params {
             lora.push_back({{"id", i}, {"scale", this->lora[i].scale}});
         }
 
+        if (only_metrics) {
+            return json {
+                {"n_predict",                 n_predict},     // Server configured n_predict
+                {"seed",                      sampling.seed},
+                {"temperature",               sampling.temp},
+                {"dynatemp_range",            sampling.dynatemp_range},
+                {"dynatemp_exponent",         sampling.dynatemp_exponent},
+                {"top_k",                     sampling.top_k},
+                {"top_p",                     sampling.top_p},
+                {"min_p",                     sampling.min_p},
+                {"top_n_sigma",               sampling.top_n_sigma},
+                {"xtc_probability",           sampling.xtc_probability},
+                {"xtc_threshold",             sampling.xtc_threshold},
+                {"typical_p",                 sampling.typ_p},
+                {"repeat_last_n",             sampling.penalty_last_n},
+                {"repeat_penalty",            sampling.penalty_repeat},
+                {"presence_penalty",          sampling.penalty_present},
+                {"frequency_penalty",         sampling.penalty_freq},
+                {"dry_multiplier",            sampling.dry_multiplier},
+                {"dry_base",                  sampling.dry_base},
+                {"dry_allowed_length",        sampling.dry_allowed_length},
+                {"dry_penalty_last_n",        sampling.dry_penalty_last_n},
+                {"mirostat",                  sampling.mirostat},
+                {"mirostat_tau",              sampling.mirostat_tau},
+                {"mirostat_eta",              sampling.mirostat_eta},
+                {"max_tokens",                n_predict}, // User configured n_predict
+                {"n_keep",                    n_keep},
+                {"n_discard",                 n_discard},
+                {"ignore_eos",                sampling.ignore_eos},
+                {"stream",                    stream},
+                {"n_probs",                   sampling.n_probs},
+                {"min_keep",                  sampling.min_keep},
+                {"chat_format",               common_chat_format_name(oaicompat_chat_syntax.format)},
+                {"reasoning_format",          common_reasoning_format_name(oaicompat_chat_syntax.reasoning_format)},
+                {"reasoning_in_content",      oaicompat_chat_syntax.reasoning_in_content},
+                {"thinking_forced_open",      oaicompat_chat_syntax.thinking_forced_open},
+                {"samplers",                  samplers},
+                {"speculative.n_max",         speculative.n_max},
+                {"speculative.n_min",         speculative.n_min},
+                {"speculative.p_min",         speculative.p_min},
+                {"timings_per_token",         timings_per_token},
+                {"post_sampling_probs",       post_sampling_probs},
+                {"lora",                      lora},
+            };
+        }
+
         auto grammar_triggers = json::array();
         for (const auto & trigger : sampling.grammar_triggers) {
-            server_grammar_trigger ct(std::move(trigger));
+            server_grammar_trigger ct(trigger);
             grammar_triggers.push_back(ct.to_json());
         }
 
@@ -1572,7 +1618,26 @@ struct server_slot {
         }
     }
 
-    json to_json() const {
+    json to_json(bool only_metrics = false) const {
+        if (only_metrics) {
+            return json {
+                {"id",            id},
+                {"id_task",       id_task},
+                {"n_ctx",         n_ctx},
+                {"speculative",   can_speculate()},
+                {"is_processing", is_processing()},
+                {"params",        params.to_json(true)},
+                {"next_token",
+                    {
+                        {"has_next_token", has_next_token},
+                        {"has_new_line",   has_new_line},
+                        {"n_remain",       n_remaining},
+                        {"n_decoded",      n_decoded},
+                    }
+                },
+            };
+        }
+
         return json {
             {"id",            id},
             {"id_task",       id_task},
@@ -2874,7 +2939,7 @@ struct server_context {
                     int n_processing_slots = 0;
 
                     for (server_slot & slot : slots) {
-                        json slot_data = slot.to_json();
+                        json slot_data = slot.to_json(true);
 
                         if (slot.is_processing()) {
                             n_processing_slots++;
@@ -4271,16 +4336,20 @@ int main(int argc, char ** argv) {
         }
     };
 
-    const auto handle_props = [&ctx_server, &res_ok](const httplib::Request &, httplib::Response & res) {
+    const auto handle_props = [&params, &ctx_server, &res_ok](const httplib::Request &, httplib::Response & res) {
         // this endpoint is publicly available, please only return what is safe to be exposed
         json data = {
             { "default_generation_settings", ctx_server.default_generation_settings_for_props },
             { "total_slots",                 ctx_server.params_base.n_parallel },
             { "model_path",                  ctx_server.params_base.model.path },
-            { "modalities",                  json{
+            { "modalities",                  json {
                 {"vision", ctx_server.oai_parser_opt.allow_image},
                 {"audio",  ctx_server.oai_parser_opt.allow_audio},
             } },
+            { "endpoint_slots",              params.endpoint_slots },
+            { "endpoint_props",              params.endpoint_props },
+            { "endpoint_metrics",            params.endpoint_metrics },
+            { "webui",                       params.webui },
             { "chat_template",               common_chat_templates_source(ctx_server.chat_templates.get()) },
             { "bos_token",                   common_token_to_piece(ctx_server.ctx, llama_vocab_bos(ctx_server.vocab), /* special= */ true)},
             { "eos_token",                   common_token_to_piece(ctx_server.ctx, llama_vocab_eos(ctx_server.vocab), /* special= */ true)},
index 82f7215d537dba3e5b3aa249cd9d7d99fc0c4028..cda7434d7c2011e353e2a7d58f0b53a425e782c1 100644 (file)
@@ -148,6 +148,8 @@ class ServerProcess:
             server_args.append("--metrics")
         if self.server_slots:
             server_args.append("--slots")
+        else:
+            server_args.append("--no-slots")
         if self.pooling:
             server_args.extend(["--pooling", self.pooling])
         if self.model_alias: