]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
common/parser: handle reasoning budget (#20297)
authorPiotr Wilkin (ilintar) <redacted>
Wed, 11 Mar 2026 09:26:12 +0000 (10:26 +0100)
committerGitHub <redacted>
Wed, 11 Mar 2026 09:26:12 +0000 (10:26 +0100)
* v1

* Finished!

* Handlie cli

* Reasoning sampler

* Apply suggestions from code review

Co-authored-by: Sigbjørn Skjæret <redacted>
* Less explosive terminology :)

* Add utf-8 case and tests

* common : migrate reasoning budget sampler to common

* cont : clean up

* cont : expose state and allow passing as initial state

* cont : remove unused imports

* cont : update state machine doc string

---------

Co-authored-by: Sigbjørn Skjæret <redacted>
Co-authored-by: Alde Rojas <redacted>
18 files changed:
common/CMakeLists.txt
common/arg.cpp
common/chat-auto-parser-generator.cpp
common/chat.cpp
common/chat.h
common/common.h
common/reasoning-budget.cpp [new file with mode: 0644]
common/reasoning-budget.h [new file with mode: 0644]
common/sampling.cpp
common/unicode.cpp
common/unicode.h
tests/CMakeLists.txt
tests/test-reasoning-budget.cpp [new file with mode: 0644]
tools/cli/cli.cpp
tools/server/server-common.cpp
tools/server/server-common.h
tools/server/server-context.cpp
tools/server/server-task.cpp

index 51bff1c44bfae71a6838be585503689cee227c16..75c6366c7fa65e265a5b37e83cb2747c27d7db95 100644 (file)
@@ -81,6 +81,8 @@ add_library(${TARGET} STATIC
     preset.cpp
     preset.h
     regex-partial.cpp
+    reasoning-budget.cpp
+    reasoning-budget.h
     regex-partial.h
     sampling.cpp
     sampling.h
index 0be6b28eb295e6ddb11ac4a0af7cbfc9fc3a34de..41da8563d632debe56a541da47de0d2d92d167e3 100644 (file)
@@ -2913,6 +2913,10 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
         [](common_params & params, const std::string & value) {
             auto parsed = json::parse(value);
             for (const auto & item : parsed.items()) {
+                if (item.key() == "enable_thinking") {
+                    LOG_WRN("Setting 'enable_thinking' via --chat-template-kwargs is deprecated. "
+                            "Use --reasoning on / --reasoning off instead.\n");
+                }
                 params.default_template_kwargs[item.key()] = item.value().dump();
             }
         }
@@ -3048,14 +3052,39 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.reasoning_format = common_reasoning_format_from_name(value);
         }
     ).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_COMPLETION, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_THINK"));
+    add_opt(common_arg(
+        {"-rea", "--reasoning"}, "[on|off|auto]",
+        "Use reasoning/thinking in the chat ('on', 'off', or 'auto', default: 'auto' (detect from template))",
+        [](common_params & params, const std::string & value) {
+            if (is_truthy(value)) {
+                params.enable_reasoning = 1;
+                params.default_template_kwargs["enable_thinking"] = "true";
+            } else if (is_falsey(value)) {
+                params.enable_reasoning = 0;
+                params.default_template_kwargs["enable_thinking"] = "false";
+            } else if (is_autoy(value)) {
+                params.enable_reasoning = -1;
+            } else {
+                throw std::invalid_argument(
+                    string_format("error: unknown value for --reasoning: '%s'\n", value.c_str()));
+            }
+        }
+    ).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_COMPLETION, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_REASONING"));
     add_opt(common_arg(
         {"--reasoning-budget"}, "N",
-        "controls the amount of thinking allowed; currently only one of: -1 for unrestricted thinking budget, or 0 to disable thinking (default: -1)",
+        "token budget for thinking: -1 for unrestricted, 0 for immediate end, N>0 for token budget (default: -1)",
         [](common_params & params, int value) {
-            if (value != 0 && value != -1) { throw std::invalid_argument("invalid value"); }
+            if (value < -1) { throw std::invalid_argument("invalid value"); }
             params.reasoning_budget = value;
         }
     ).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_COMPLETION, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_THINK_BUDGET"));
+    add_opt(common_arg(
+        {"--reasoning-budget-message"}, "MESSAGE",
+        "message injected before the end-of-thinking tag when reasoning budget is exhausted (default: none)",
+        [](common_params & params, const std::string & value) {
+            params.reasoning_budget_message = value;
+        }
+    ).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_COMPLETION, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_THINK_BUDGET_MESSAGE"));
     add_opt(common_arg(
         {"--chat-template"}, "JINJA_TEMPLATE",
         string_format(
index 1c74ad30d9d99739fb7d5c6bc262fcb326666b4a..b7cf513942bbec259db8206e08099e013bd6b1e7 100644 (file)
@@ -135,7 +135,9 @@ common_peg_parser analyze_reasoning::build_parser(parser_build_context & ctx) co
     if (thinking_forced_open || thinking_forced_closed) {
         // Thinking is forced open OR forced closed with enable_thinking=true
         // In both cases, expect only the closing tag (opening was in template)
-        return p.reasoning(p.until(end)) + end;
+        // However, since we might have incorrectly detected the open/close pattern,
+        // we admit an optional starting marker
+        return p.optional(p.literal(start)) + p.reasoning(p.until(end)) + end;
     }
     if (mode == reasoning_mode::TAG_BASED || mode == reasoning_mode::TOOLS_ONLY) {
         // Standard tag-based reasoning OR tools-only mode (reasoning appears with tools)
index 60fd64ff91a00980a994989ffc9c566787830375..b799912ae49aa84583a26ec5d1aa3f5e3d4a083b 100644 (file)
@@ -857,7 +857,9 @@ static common_chat_params common_chat_params_init_ministral_3(const common_chat_
     auto extract_reasoning = inputs.reasoning_format != COMMON_REASONING_FORMAT_NONE;
     auto include_grammar   = true;
 
-    data.supports_thinking = true;
+    data.supports_thinking  = true;
+    data.thinking_start_tag = "[THINK]";
+    data.thinking_end_tag   = "[/THINK]";
     data.prompt            = common_chat_template_direct_apply(tmpl, inputs, /* messages_override = */ adjusted_messages);
     data.format            = COMMON_CHAT_FORMAT_PEG_NATIVE;
     data.preserved_tokens  = {
@@ -1165,9 +1167,11 @@ static common_chat_params common_chat_params_init_kimi_k2(const common_chat_temp
                                                           const autoparser::templates_params & inputs) {
     common_chat_params data;
 
-    data.prompt            = common_chat_template_direct_apply(tmpl, inputs);
-    data.format            = COMMON_CHAT_FORMAT_PEG_NATIVE;
-    data.supports_thinking = true;
+    data.prompt             = common_chat_template_direct_apply(tmpl, inputs);
+    data.format             = COMMON_CHAT_FORMAT_PEG_NATIVE;
+    data.supports_thinking  = true;
+    data.thinking_start_tag = "<think>";
+    data.thinking_end_tag   = "</think>";
     data.preserved_tokens  = {
         "<|tool_calls_section_begin|>",
         "<|tool_calls_section_end|>",
@@ -1527,6 +1531,16 @@ static common_chat_params common_chat_templates_apply_jinja(const struct common_
         autoparser.analyze_template(tmpl);
         auto auto_params = autoparser::peg_generator::generate_parser(tmpl, params, autoparser);
         auto_params.supports_thinking = autoparser.reasoning.mode != autoparser::reasoning_mode::NONE;
+        if (auto_params.supports_thinking) {
+            auto_params.thinking_start_tag = autoparser.reasoning.start;
+            auto_params.thinking_end_tag   = autoparser.reasoning.end;
+            // FORCED_OPEN and FORCED_CLOSED both put <think> in the generation prompt
+            // (FORCED_CLOSED forces empty <think></think> when thinking is disabled,
+            //  but forces <think> open when thinking is enabled)
+            auto_params.thinking_forced_open =
+                autoparser.reasoning.mode == autoparser::reasoning_mode::FORCED_OPEN ||
+                autoparser.reasoning.mode == autoparser::reasoning_mode::FORCED_CLOSED;
+        }
         return auto_params;
     } catch (const std::exception & e) {
         throw std::invalid_argument(std::string("Unable to generate parser for this template. Automatic parser generation failed: ") + e.what());
index 005cc5c8b3fd63c3284fece71fa1b272a62321d3..930987cf77b9ab6622ed279a66e4b8772e9e1564 100644 (file)
@@ -213,6 +213,8 @@ struct common_chat_params {
     bool                                grammar_lazy         = false;
     bool                                thinking_forced_open = false;
     bool                                supports_thinking    = false;
+    std::string                         thinking_start_tag;  // e.g., "<think>"
+    std::string                         thinking_end_tag;    // e.g., "</think>"
     std::vector<common_grammar_trigger> grammar_triggers;
     std::vector<std::string>            preserved_tokens;
     std::vector<std::string>            additional_stops;
index 440eb9720070e399fde90fcabff8fb7e5d02fb7f..ffaeefd7c9423617787b83f0fdfedad5e6f1fb46 100644 (file)
@@ -235,6 +235,14 @@ struct common_params_sampling {
     std::vector<llama_logit_bias> logit_bias;     // logit biases to apply
     std::vector<llama_logit_bias> logit_bias_eog; // pre-calculated logit biases for EOG tokens
 
+    // reasoning budget sampler parameters
+    // these are populated by the server/CLI based on chat template params
+    int32_t                  reasoning_budget_tokens   = -1;   // -1 = disabled, >= 0 = token budget
+    bool                     reasoning_budget_activate_immediately = false;
+    std::vector<llama_token> reasoning_budget_start;           // start tag token sequence
+    std::vector<llama_token> reasoning_budget_end;             // end tag token sequence
+    std::vector<llama_token> reasoning_budget_forced;          // forced sequence (message + end tag)
+
     bool backend_sampling = false;
 
     bool has_logit_bias() const {
@@ -536,7 +544,9 @@ struct common_params {
     bool use_jinja = true;                                                                                  // NOLINT
     bool enable_chat_template = true;
     common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK;
+    int enable_reasoning = -1; // -1 = auto, 0 = disable, 1 = enable
     int reasoning_budget = -1;
+    std::string reasoning_budget_message; // message injected before end tag when budget exhausted
     bool prefill_assistant = true; // if true, any trailing assistant message will be prefilled into the response
     int sleep_idle_seconds = -1;   // if >0, server will sleep after this many seconds of idle time
 
diff --git a/common/reasoning-budget.cpp b/common/reasoning-budget.cpp
new file mode 100644 (file)
index 0000000..a55e4f5
--- /dev/null
@@ -0,0 +1,219 @@
+#include "reasoning-budget.h"
+#include "common.h"
+#include "unicode.h"
+
+#include "log.h"
+
+#include <cmath>
+#include <cstdint>
+#include <string>
+#include <vector>
+
+struct token_matcher {
+    std::vector<llama_token> tokens;
+    size_t pos = 0;
+
+    bool advance(llama_token token) {
+        if (tokens.empty()) {
+            return false;
+        }
+
+        if (token == tokens[pos]) {
+            pos++;
+            if (pos >= tokens.size()) {
+                pos = 0;
+                return true;
+            }
+        } else {
+            pos = 0;
+            if (token == tokens[0]) {
+                pos = 1;
+            }
+        }
+        return false;
+    }
+
+    void reset() { pos = 0; }
+};
+
+struct common_reasoning_budget_ctx {
+    const llama_vocab * vocab;
+
+    token_matcher start_matcher;
+    token_matcher end_matcher;
+    std::vector<llama_token> forced_tokens;
+
+    int32_t budget;           // maximum tokens in reasoning block
+    int32_t remaining;        // tokens remaining in budget
+
+    common_reasoning_budget_state state;
+
+    // for forcing
+    size_t force_pos;         // next position in forced_tokens to force
+};
+
+static const char * common_reasoning_budget_name(const struct llama_sampler * /*smpl*/) {
+    return "reasoning-budget";
+}
+
+static void common_reasoning_budget_accept(struct llama_sampler * smpl, llama_token token) {
+    auto * ctx = (common_reasoning_budget_ctx *) smpl->ctx;
+
+    switch (ctx->state) {
+        case REASONING_BUDGET_IDLE:
+        {
+            if (ctx->start_matcher.advance(token)) {
+                ctx->state = REASONING_BUDGET_COUNTING;
+                ctx->remaining = ctx->budget;
+                LOG_INF("reasoning-budget: activated, budget=%d tokens\n", ctx->budget);
+
+                if (ctx->remaining <= 0) {
+                    ctx->state = REASONING_BUDGET_FORCING;
+                    ctx->force_pos = 0;
+                    LOG_INF("reasoning-budget: budget=0, forcing immediately\n");
+                }
+            }
+            break;
+        }
+        case REASONING_BUDGET_COUNTING:
+        case REASONING_BUDGET_WAITING_UTF8:
+        {
+            if (ctx->end_matcher.advance(token)) {
+                ctx->state = REASONING_BUDGET_DONE;
+                LOG_INF("reasoning-budget: deactivated (natural end)\n");
+                break;
+            }
+
+            bool utf8_complete = true;
+            if (ctx->vocab != nullptr) {
+                const std::string piece = common_token_to_piece(ctx->vocab, token, false);
+                utf8_complete = common_utf8_is_complete(piece);
+            }
+
+            if (ctx->state == REASONING_BUDGET_WAITING_UTF8) {
+                if (utf8_complete) {
+                    ctx->state = REASONING_BUDGET_FORCING;
+                    ctx->force_pos = 0;
+                    ctx->end_matcher.reset();
+                    LOG_INF("reasoning-budget: UTF-8 complete, now forcing end sequence\n");
+                }
+            } else if (ctx->state == REASONING_BUDGET_COUNTING) {
+                ctx->remaining--;
+                if (ctx->remaining <= 0) {
+                    if (utf8_complete) {
+                        ctx->state = REASONING_BUDGET_FORCING;
+                        ctx->force_pos = 0;
+                        ctx->end_matcher.reset();
+                        LOG_INF("reasoning-budget: budget exhausted, forcing end sequence\n");
+                    } else {
+                        ctx->state = REASONING_BUDGET_WAITING_UTF8;
+                        ctx->end_matcher.reset();
+                        LOG_INF("reasoning-budget: budget exhausted, waiting for UTF-8 completion\n");
+                    }
+                }
+            }
+            break;
+        }
+        case REASONING_BUDGET_FORCING:
+            // force_pos is advanced in apply(), not here.
+            // This ensures the first forced token isn't skipped when the sampler
+            // is initialized directly in FORCING state (e.g. COUNTING + budget=0)
+            break;
+        case REASONING_BUDGET_DONE:
+            break;
+    }
+}
+
+static void common_reasoning_budget_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
+    auto * ctx = (common_reasoning_budget_ctx *) smpl->ctx;
+
+    if (ctx->state != REASONING_BUDGET_FORCING) {
+        // passthrough — don't modify logits
+        return;
+    }
+
+    if (ctx->force_pos >= ctx->forced_tokens.size()) {
+        return;
+    }
+
+    const llama_token forced = ctx->forced_tokens[ctx->force_pos];
+
+    // set all logits to -inf except the forced token
+    for (size_t i = 0; i < cur_p->size; i++) {
+        if (cur_p->data[i].id != forced) {
+            cur_p->data[i].logit = -INFINITY;
+        }
+    }
+
+    // advance to next forced token (done here rather than in accept so that
+    // the first forced token isn't skipped when starting in FORCING state)
+    ctx->force_pos++;
+    if (ctx->force_pos >= ctx->forced_tokens.size()) {
+        ctx->state = REASONING_BUDGET_DONE;
+        LOG_INF("reasoning-budget: forced sequence complete, done\n");
+    }
+}
+
+static void common_reasoning_budget_reset(struct llama_sampler * smpl) {
+    auto * ctx = (common_reasoning_budget_ctx *) smpl->ctx;
+    ctx->state = REASONING_BUDGET_IDLE;
+    ctx->remaining = ctx->budget;
+    ctx->start_matcher.reset();
+    ctx->end_matcher.reset();
+    ctx->force_pos = 0;
+}
+
+static struct llama_sampler * common_reasoning_budget_clone(const struct llama_sampler * smpl) {
+    const auto * ctx = (const common_reasoning_budget_ctx *) smpl->ctx;
+    return common_reasoning_budget_init(
+        ctx->vocab,
+        ctx->start_matcher.tokens,
+        ctx->end_matcher.tokens,
+        ctx->forced_tokens,
+        ctx->budget,
+        ctx->state);
+}
+
+static void common_reasoning_budget_free(struct llama_sampler * smpl) {
+    delete (common_reasoning_budget_ctx *) smpl->ctx;
+}
+
+static struct llama_sampler_i common_reasoning_budget_i = {
+    /* .name              = */ common_reasoning_budget_name,
+    /* .accept            = */ common_reasoning_budget_accept,
+    /* .apply             = */ common_reasoning_budget_apply,
+    /* .reset             = */ common_reasoning_budget_reset,
+    /* .clone             = */ common_reasoning_budget_clone,
+    /* .free              = */ common_reasoning_budget_free,
+    /* .backend_init      = */ nullptr,
+    /* .backend_accept    = */ nullptr,
+    /* .backend_apply     = */ nullptr,
+    /* .backend_set_input = */ nullptr,
+};
+
+struct llama_sampler * common_reasoning_budget_init(
+        const struct llama_vocab       * vocab,
+        const std::vector<llama_token> & start_tokens,
+        const std::vector<llama_token> & end_tokens,
+        const std::vector<llama_token> & forced_tokens,
+        int32_t                          budget,
+        common_reasoning_budget_state    initial_state) {
+    // promote COUNTING with budget <= 0 to FORCING
+    if (initial_state == REASONING_BUDGET_COUNTING && budget <= 0) {
+        initial_state = REASONING_BUDGET_FORCING;
+    }
+
+    return llama_sampler_init(
+        /* .iface = */ &common_reasoning_budget_i,
+        /* .ctx   = */ new common_reasoning_budget_ctx {
+            /* .vocab         = */ vocab,
+            /* .start_matcher = */ { start_tokens, 0 },
+            /* .end_matcher   = */ { end_tokens, 0 },
+            /* .forced_tokens = */ forced_tokens,
+            /* .budget        = */ budget,
+            /* .remaining     = */ budget,
+            /* .state         = */ initial_state,
+            /* .force_pos     = */ 0,
+        }
+    );
+}
diff --git a/common/reasoning-budget.h b/common/reasoning-budget.h
new file mode 100644 (file)
index 0000000..08ad282
--- /dev/null
@@ -0,0 +1,41 @@
+#pragma once
+
+#include "llama.h"
+
+#include <cstdint>
+#include <vector>
+
+enum common_reasoning_budget_state {
+    REASONING_BUDGET_IDLE,         // waiting for start sequence
+    REASONING_BUDGET_COUNTING,     // counting down tokens
+    REASONING_BUDGET_FORCING,      // forcing budget message + end sequence
+    REASONING_BUDGET_WAITING_UTF8, // budget exhausted, waiting for UTF-8 completion
+    REASONING_BUDGET_DONE,         // passthrough forever
+};
+
+// Creates a reasoning budget sampler that limits token generation inside a
+// reasoning block (e.g. between <think> and </think>).
+//
+// State machine: IDLE -> COUNTING -> WAITING_UTF8 -> FORCING -> DONE
+//   IDLE:         passthrough, watching for start_tokens sequence
+//   COUNTING:     counting down remaining tokens, watching for natural end_tokens
+//   WAITING_UTF8: budget exhausted, allowing tokens to complete a UTF-8 sequence
+//   FORCING:      forces forced_tokens token-by-token (all other logits -> -inf)
+//   DONE:         passthrough forever
+//
+// Parameters:
+//   vocab         - vocabulary (used for UTF-8 boundary detection; can be nullptr)
+//   start_tokens  - token sequence that activates counting
+//   end_tokens    - token sequence for natural deactivation
+//   forced_tokens - token sequence forced when budget expires
+//   budget        - max tokens allowed in the reasoning block
+//   initial_state - initial state of the sampler (e.g. IDLE or COUNTING)
+//                   note: COUNTING with budget <= 0 is promoted to FORCING
+//
+struct llama_sampler * common_reasoning_budget_init(
+        const struct llama_vocab       * vocab,
+        const std::vector<llama_token> & start_tokens,
+        const std::vector<llama_token> & end_tokens,
+        const std::vector<llama_token> & forced_tokens,
+        int32_t                          budget,
+        common_reasoning_budget_state    initial_state);
index 11a1d483980d0599adcb0a437eea131295545fcd..f849d4f61af0187d8653cd31f860506b884c41f3 100644 (file)
@@ -2,6 +2,7 @@
 
 #include "common.h"
 #include "log.h"
+#include "reasoning-budget.h"
 
 #include <algorithm>
 #include <cmath>
@@ -250,6 +251,17 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, st
         }
     }
 
+    // reasoning budget sampler — added first so it can force tokens before other samplers
+    if (params.reasoning_budget_tokens >= 0 && !params.reasoning_budget_forced.empty()) {
+        samplers.push_back(common_reasoning_budget_init(
+            vocab,
+            params.reasoning_budget_start,
+            params.reasoning_budget_end,
+            params.reasoning_budget_forced,
+            params.reasoning_budget_tokens,
+            params.reasoning_budget_activate_immediately ? REASONING_BUDGET_COUNTING : REASONING_BUDGET_IDLE));
+    }
+
     if (params.has_logit_bias()) {
         samplers.push_back(llama_sampler_init_logit_bias(llama_vocab_n_tokens(vocab), params.logit_bias.size(), params.logit_bias.data()));
     }
index c0ef6d02926df5113314574ee4108a2eec998795..f71fe56783ff3f200d3f5c852ea1788fb2312902 100644 (file)
@@ -1,8 +1,10 @@
 #include "unicode.h"
+
+#include <algorithm>
 #include <cassert>
 #include <stdexcept>
-#include <vector>
 #include <string>
+#include <vector>
 
 // implementation adopted from src/unicode.cpp
 
@@ -67,6 +69,20 @@ utf8_parse_result common_parse_utf8_codepoint(std::string_view input, size_t off
     return utf8_parse_result(utf8_parse_result::INVALID);
 }
 
+bool common_utf8_is_complete(const std::string & s) {
+    if (s.empty()) {
+        return true;
+    }
+    for (int i = 1; i <= std::min(4, (int)s.size()); i++) {
+        unsigned char c = s[s.size() - i];
+        if ((c & 0xC0) != 0x80) {
+            int expected = (c >= 0xF0) ? 4 : (c >= 0xE0) ? 3 : (c >= 0xC0) ? 2 : 1;
+            return i >= expected;
+        }
+    }
+    return false;
+}
+
 std::string common_unicode_cpts_to_utf8(const std::vector<uint32_t> & cps) {
     std::string result;
     for (size_t i = 0; i < cps.size(); ++i) {
index 87bcc0ffcafe435cd7fdb4f941cf64e50d9438ea..9b32fa19d62bc46f7430c10b26a282f4bb9cb11e 100644 (file)
@@ -20,6 +20,9 @@ struct utf8_parse_result {
 // Returns 0 for invalid first bytes
 size_t common_utf8_sequence_length(unsigned char first_byte);
 
+// Check if a string ends with a complete UTF-8 sequence.
+bool common_utf8_is_complete(const std::string & s);
+
 // Parse a single UTF-8 codepoint from input
 utf8_parse_result common_parse_utf8_codepoint(std::string_view input, size_t offset);
 
index 7fd895e2b6408ee3cb3cb98b685759a8fce7e752..bb0f0ef0ed8532eaed25895a00a46f424a011952 100644 (file)
@@ -149,6 +149,7 @@ endif ()
 if (NOT WIN32 OR NOT BUILD_SHARED_LIBS)
     # these tests are disabled on Windows because they use internal functions not exported with LLAMA_API (when building with shared libraries)
     llama_build_and_test(test-sampling.cpp)
+    llama_build_and_test(test-reasoning-budget.cpp)
     llama_build_and_test(test-grammar-parser.cpp)
     llama_build_and_test(test-grammar-integration.cpp)
     llama_build_and_test(test-llama-grammar.cpp)
diff --git a/tests/test-reasoning-budget.cpp b/tests/test-reasoning-budget.cpp
new file mode 100644 (file)
index 0000000..ab540a8
--- /dev/null
@@ -0,0 +1,238 @@
+#include "reasoning-budget.h"
+#include "unicode.h"
+
+#include "llama.h"
+#include "ggml.h"
+
+#ifdef NDEBUG
+#undef NDEBUG
+#endif
+
+#include <cmath>
+#include <cstddef>
+#include <cstdio>
+#include <string>
+#include <vector>
+
+// Reasoning budget sampler test helper
+// These tests use nullptr vocab which safely falls back to treating all tokens as complete
+// (The UTF-8 boundary detection logic is tested separately in test_utf8_boundary_detection)
+static void test_reasoning_budget(
+    const char * test_name,
+    const std::vector<llama_token> & sequence,
+    const std::vector<llama_token> & start_tokens,
+    const std::vector<llama_token> & end_tokens,
+    const std::vector<llama_token> & forced_tokens,
+    int32_t budget,
+    common_reasoning_budget_state initial_state,
+    size_t expected_force_start,   // token index where forcing should start (SIZE_MAX = never)
+    size_t expected_force_end      // token index where forcing should end (after this, no more forcing)
+) {
+    // Find the maximum token ID to ensure our vocab covers all tokens
+    llama_token max_token = 0;
+    for (auto t : sequence) max_token = std::max(max_token, t);
+    for (auto t : start_tokens) max_token = std::max(max_token, t);
+    for (auto t : end_tokens) max_token = std::max(max_token, t);
+    for (auto t : forced_tokens) max_token = std::max(max_token, t);
+
+    // Create a minimal sampler with mock vocabulary
+    // For this test, we use nullptr as vocab since we're testing state transitions
+    // The UTF-8 boundary check will treat all tokens as complete (safe fallback)
+    auto * sampler = common_reasoning_budget_init(
+        nullptr,  // vocab - not used for basic state machine tests
+        start_tokens,
+        end_tokens,
+        forced_tokens,
+        budget,
+        initial_state
+    );
+
+    // Create a test token data array for checking forcing behavior
+    // Vocab size must be large enough to include all tokens (start, end, forced, sequence)
+    std::vector<llama_token_data> cur;
+    const size_t n_vocab = (size_t)max_token + 1;
+    for (size_t i = 0; i < n_vocab; i++) {
+        cur.emplace_back(llama_token_data{(llama_token)i, logf((float)(i+1)), 0.0f});
+    }
+    llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false };
+
+    size_t actual_force_start = SIZE_MAX;
+    size_t actual_force_end = SIZE_MAX;
+
+    // Feed the sequence and track when forcing occurs
+    for (size_t i = 0; i < sequence.size(); i++) {
+        llama_sampler_accept(sampler, sequence[i]);
+
+        // Check if we're in forcing state by applying and seeing if logits are modified
+        cur_p.selected = -1;
+        for (size_t j = 0; j < cur.size(); j++) {
+            cur[j].logit = logf((float)(j+1));  // reset logits
+        }
+
+        llama_sampler_apply(sampler, &cur_p);
+
+        // Check if forcing is active (all logits except one should be -INFINITY)
+        size_t finite_count = 0;
+        llama_token finite_token = -1;
+        for (size_t j = 0; j < cur.size(); j++) {
+            if (std::isfinite(cur[j].logit)) {
+                finite_count++;
+                finite_token = cur[j].id;
+            }
+        }
+
+        fprintf(stderr, "    i=%zu: token=%d, finite_count=%zu, finite_token=%d\n", i, (int)sequence[i], finite_count, (int)finite_token);
+
+        if (finite_count == 1) {
+            if (actual_force_start == SIZE_MAX) {
+                actual_force_start = i;
+            }
+            actual_force_end = i;
+        } else if (actual_force_start != SIZE_MAX && actual_force_end != SIZE_MAX) {
+            // Forcing stopped
+            break;
+        }
+    }
+
+    llama_sampler_free(sampler);
+
+    // Verify forcing occurred at expected positions
+    if (expected_force_start == SIZE_MAX) {
+        if (actual_force_start != SIZE_MAX) {
+            fprintf(stderr, "Test '%s' FAILED: Expected no forcing, but forcing occurred at %zu\n", test_name, actual_force_start);
+            GGML_ASSERT(false && "Expected no forcing, but forcing occurred");
+        }
+    } else {
+        if (actual_force_start == SIZE_MAX) {
+            fprintf(stderr, "Test '%s' FAILED: Expected forcing but none occurred\n", test_name);
+            GGML_ASSERT(false && "Expected forcing but none occurred");
+        }
+        if (actual_force_start != expected_force_start) {
+            fprintf(stderr, "Test '%s' FAILED: Forcing started at %zu, expected %zu\n", test_name, actual_force_start, expected_force_start);
+            GGML_ASSERT(false && "Forcing started at wrong position");
+        }
+    }
+
+    if (expected_force_end != SIZE_MAX) {
+        if (actual_force_end < expected_force_end) {
+            fprintf(stderr, "Test '%s' FAILED: Forcing ended at %zu, expected >= %zu\n", test_name, actual_force_end, expected_force_end);
+            GGML_ASSERT(false && "Forcing ended too early");
+        }
+    }
+
+    fprintf(stderr, "  Test '%s' passed (force_start=%zu, force_end=%zu)\n", test_name, actual_force_start, actual_force_end);
+    (void)sequence;
+}
+
+// UTF-8 boundary detection unit test
+// Tests common_utf8_is_complete() from reasoning-budget.h
+static void test_utf8_boundary_detection() {
+    // Complete sequences
+    GGML_ASSERT(common_utf8_is_complete("hello"));
+    GGML_ASSERT(common_utf8_is_complete(""));
+    GGML_ASSERT(common_utf8_is_complete("\xC2\xA0"));            // complete 2-byte UTF-8 (U+00A0)
+    GGML_ASSERT(common_utf8_is_complete("\xE2\x80\x9C"));        // complete 3-byte UTF-8 (left double quote)
+    GGML_ASSERT(common_utf8_is_complete("\xF0\x9F\x98\x80"));    // complete 4-byte UTF-8 (emoji)
+    GGML_ASSERT(common_utf8_is_complete("abc\xC3\xA9"));         // ASCII + complete 2-byte
+
+    // Incomplete sequences
+    GGML_ASSERT(!common_utf8_is_complete(std::string("\xC2", 1)));            // 2-byte start, missing continuation
+    GGML_ASSERT(!common_utf8_is_complete(std::string("\xE2\x80", 2)));        // 3-byte start + 1 cont, missing 1
+    GGML_ASSERT(!common_utf8_is_complete(std::string("\xE2", 1)));            // 3-byte start, missing 2
+    GGML_ASSERT(!common_utf8_is_complete(std::string("\xF0\x9F\x98", 3)));    // 4-byte start + 2 cont, missing 1
+    GGML_ASSERT(!common_utf8_is_complete(std::string("\xF0\x9F", 2)));        // 4-byte start + 1 cont, missing 2
+    GGML_ASSERT(!common_utf8_is_complete(std::string("\xF0", 1)));            // 4-byte start, missing 3
+    GGML_ASSERT(!common_utf8_is_complete(std::string("\x80", 1)));            // orphan continuation byte
+
+    // Mixed: ASCII followed by start of multi-byte
+    GGML_ASSERT(!common_utf8_is_complete(std::string("hello\xC3", 6)));       // ASCII + incomplete 2-byte
+    GGML_ASSERT(common_utf8_is_complete(std::string("hello\xC3\xA9", 7)));    // ASCII + complete 2-byte
+}
+
+int main(void) {
+    // Reasoning budget sampler tests
+    printf("Testing reasoning budget sampler... ");
+
+    // Test 1: Basic budget with start/end tokens - no forcing (natural end before budget exhausted)
+    {
+        const std::vector<llama_token> start = {100};  // start token
+        const std::vector<llama_token> end = {101};    // end token
+        const std::vector<llama_token> forced = {102}; // forced token (not used in this test)
+        const std::vector<llama_token> sequence = {100, 50, 51, 101, 52}; // start, two tokens, end, one more
+
+        test_reasoning_budget("natural end before budget exhausted", sequence, start, end, forced,
+            5,      // budget of 5 tokens
+            REASONING_BUDGET_IDLE,
+            SIZE_MAX, SIZE_MAX); // no forcing expected (natural end)
+    }
+
+    // Test 2: Budget exhausted, forcing should occur
+    // Flow: i=0 accept(100)->COUNTING, i=1 accept(50)->remaining=1, i=2 accept(51)->remaining=0->FORCING
+    // Forcing is active at i=2 and i=3 (when apply() is called while in FORCING state)
+    // At i=4, force_pos becomes 2 which equals forced_tokens.size(), so state becomes DONE
+    {
+        const std::vector<llama_token> start = {100};
+        const std::vector<llama_token> end = {101};
+        const std::vector<llama_token> forced = {102, 101}; // forced message + end
+        const std::vector<llama_token> sequence = {100, 50, 51, 52, 53}; // start + 4 tokens (budget=2)
+
+        test_reasoning_budget("budget exhausted forcing", sequence, start, end, forced,
+            2,      // budget of 2 tokens
+            REASONING_BUDGET_IDLE,
+            2,      // forcing starts at i=2 (after accept(51) depletes budget, apply() forces)
+            3);     // forcing continues through i=3 (at i=4 state becomes DONE)
+    }
+
+    // Test 3: Activate immediately with budget=0, forcing should start right away
+    // Flow: Since no start token in sequence, state stays IDLE (no start/end configured means passthrough)
+    // This test needs start token to be in the sequence or use activate_immediately with start token present
+    {
+        const std::vector<llama_token> start = {100};
+        const std::vector<llama_token> end = {101};
+        const std::vector<llama_token> forced = {102, 101};
+        const std::vector<llama_token> sequence = {100, 50, 51, 52}; // start token first, then 3 tokens
+
+        test_reasoning_budget("activate immediately budget=0", sequence, start, end, forced,
+            0,      // budget of 0 tokens
+            REASONING_BUDGET_COUNTING, // starts counting, promoted to FORCING since budget=0
+            0,      // forcing starts at i=0 (after accept(100), budget=0 goes straight to FORCING)
+            1);     // forcing continues through i=1 (at i=2 state becomes DONE)
+    }
+
+    // Test 4: No start/end tokens configured - passthrough (no forcing)
+    {
+        const std::vector<llama_token> start = {};
+        const std::vector<llama_token> end = {};
+        const std::vector<llama_token> forced = {102};
+        const std::vector<llama_token> sequence = {50, 51, 52, 53};
+
+        test_reasoning_budget("no start/end configured", sequence, start, end, forced,
+            2,      // budget
+            REASONING_BUDGET_IDLE,
+            SIZE_MAX, SIZE_MAX); // no forcing (no start/end configured)
+    }
+
+    // Test 5: Activate immediately with budget > 0, count down then force
+    // Flow: i=0 accept(50)->remaining=1, i=1 accept(51)->remaining=0->FORCING
+    // So forcing starts at i=1 (apply after accept sees FORCING with force_pos=0)
+    {
+        const std::vector<llama_token> start = {100};
+        const std::vector<llama_token> end = {101};
+        const std::vector<llama_token> forced = {102, 101};
+        const std::vector<llama_token> sequence = {50, 51, 52, 53};
+
+        test_reasoning_budget("activate immediately with budget", sequence, start, end, forced,
+            2,      // budget of 2 tokens
+            REASONING_BUDGET_COUNTING,
+            1,      // forcing starts at i=1 (after 2 accepts deplete budget)
+            2);     // forcing continues through i=2
+    }
+
+    printf("OK (5 tests passed)\n");
+
+    printf("Testing UTF-8 boundary detection... ");
+    test_utf8_boundary_detection();
+    printf("OK\n");
+
+    return 0;
+}
index d43d105490753521adbbe6ef1de3f8fa99174221..4c2ae7a033ca10716d7c02055ada59da344692d1 100644 (file)
@@ -57,6 +57,8 @@ struct cli_context {
     std::vector<raw_buffer> input_files;
     task_params defaults;
     bool verbose_prompt;
+    int reasoning_budget = -1;
+    std::string reasoning_budget_message;
 
     // thread for showing "loading" animation
     std::atomic<bool> loading_show;
@@ -73,6 +75,8 @@ struct cli_context {
         // defaults.return_progress = true; // TODO: show progress
 
         verbose_prompt = params.verbose_prompt;
+        reasoning_budget = params.reasoning_budget;
+        reasoning_budget_message = params.reasoning_budget_message;
     }
 
     std::string generate_completion(result_timings & out_timings) {
@@ -95,6 +99,24 @@ struct cli_context {
                 task.params.chat_parser_params.parser.load(chat_params.parser);
             }
 
+            // reasoning budget sampler
+            if (reasoning_budget >= 0 && !chat_params.thinking_end_tag.empty()) {
+                const llama_vocab * vocab = llama_model_get_vocab(
+                    llama_get_model(ctx_server.get_llama_context()));
+
+                task.params.sampling.reasoning_budget_tokens = reasoning_budget;
+                task.params.sampling.reasoning_budget_activate_immediately = chat_params.thinking_forced_open;
+
+                if (!chat_params.thinking_start_tag.empty()) {
+                    task.params.sampling.reasoning_budget_start =
+                        common_tokenize(vocab, chat_params.thinking_start_tag, false, true);
+                }
+                task.params.sampling.reasoning_budget_end =
+                    common_tokenize(vocab, chat_params.thinking_end_tag, false, true);
+                task.params.sampling.reasoning_budget_forced =
+                    common_tokenize(vocab, reasoning_budget_message + chat_params.thinking_end_tag, false, true);
+            }
+
             rd.post_task({std::move(task)});
         }
 
index 5b8895b341ef852d1c9be987c570f592f6f2eb52..bd203228cce05315810da7d28e2d221a9c4fa69f 100644 (file)
@@ -1101,6 +1101,22 @@ json oaicompat_chat_params_parse(
         llama_params["chat_parser"] = chat_params.parser;
     }
 
+    // Reasoning budget: pass parameters through to sampling layer
+    {
+        int reasoning_budget = opt.reasoning_budget;
+        if (reasoning_budget == -1 && body.contains("thinking_budget_tokens")) {
+            reasoning_budget = json_value(body, "thinking_budget_tokens", -1);
+        }
+
+        if (reasoning_budget >= 0 && !chat_params.thinking_end_tag.empty()) {
+            llama_params["reasoning_budget_tokens"] = reasoning_budget;
+            llama_params["reasoning_budget_start_tag"] = chat_params.thinking_start_tag;
+            llama_params["reasoning_budget_end_tag"] = chat_params.thinking_end_tag;
+            llama_params["reasoning_budget_message"] = opt.reasoning_budget_message;
+            llama_params["reasoning_budget_activate_immediately"] = chat_params.thinking_forced_open;
+        }
+    }
+
     // Handle "logprobs" field
     // TODO: The response format of this option is not yet OAI-compatible, but seems like no one really using it; We may need to fix it in the future
     if (json_value(body, "logprobs", false)) {
index a234541e199fd012f10e826c8e8eb666acd53614..3e56b3d856dcad6df326a0479f77f4f364659ebc 100644 (file)
@@ -287,6 +287,8 @@ struct server_chat_params {
     bool allow_image;
     bool allow_audio;
     bool enable_thinking = true;
+    int  reasoning_budget = -1;
+    std::string reasoning_budget_message;
     std::string media_path;
 };
 
index b86e7e608e0147c77a0c7b11b4c250a94371ded1..b4373c101b428f43ab1b313f75b53ced04af52dc 100644 (file)
@@ -893,9 +893,10 @@ private:
             }
 
             // thinking is enabled if:
-            // 1. It's not explicitly disabled (reasoning_budget == 0)
+            // 1. It's not explicitly disabled via --reasoning off
             // 2. The chat template supports it
-            const bool enable_thinking = params_base.use_jinja && params_base.reasoning_budget != 0 && common_chat_templates_support_enable_thinking(chat_templates.get());
+            const bool template_supports_thinking = params_base.use_jinja && common_chat_templates_support_enable_thinking(chat_templates.get());
+            const bool enable_thinking = params_base.enable_reasoning != 0 && template_supports_thinking;
             SRV_INF("%s: chat template, thinking = %d\n", __func__, enable_thinking);
 
             chat_params = {
@@ -907,6 +908,8 @@ private:
                 /* allow_image           */ mctx ? mtmd_support_vision(mctx) : false,
                 /* allow_audio           */ mctx ? mtmd_support_audio (mctx) : false,
                 /* enable_thinking       */ enable_thinking,
+                /* reasoning_budget      */ params_base.reasoning_budget,
+                /* reasoning_budget_msg  */ params_base.reasoning_budget_message,
                 /* media_path            */ params_base.media_path,
             };
         }
index 9d6e422d62c4245a9ad62691bcad67a79a659cac..b3d510977b2b83d0076afbeeb21981be113dd3b9 100644 (file)
@@ -462,6 +462,34 @@ task_params server_task::params_from_json_cmpl(
         }
     }
 
+    // Parse reasoning budget sampler parameters
+    {
+        const int32_t budget = json_value(data, "reasoning_budget_tokens", (int32_t) -1);
+        if (budget >= 0) {
+            const auto start_tag = json_value(data, "reasoning_budget_start_tag", std::string());
+            const auto end_tag   = json_value(data, "reasoning_budget_end_tag", std::string());
+            const auto message   = json_value(data, "reasoning_budget_message", std::string());
+            const bool activate_imm   = json_value(data, "reasoning_budget_activate_immediately", false);
+
+            params.sampling.reasoning_budget_tokens = budget;
+            params.sampling.reasoning_budget_activate_immediately = activate_imm;
+
+            if (!start_tag.empty()) {
+                params.sampling.reasoning_budget_start = common_tokenize(vocab, start_tag, false, true);
+            }
+            if (!end_tag.empty()) {
+                params.sampling.reasoning_budget_end = common_tokenize(vocab, end_tag, false, true);
+                params.sampling.reasoning_budget_forced = common_tokenize(vocab, message + end_tag, false, true);
+            }
+
+            SRV_DBG("reasoning budget: tokens=%d, activate_immediately=%s, start=%zu toks, end=%zu toks, forced=%zu toks\n",
+                budget, activate_imm ? "true" : "false",
+                params.sampling.reasoning_budget_start.size(),
+                params.sampling.reasoning_budget_end.size(),
+                params.sampling.reasoning_budget_forced.size());
+        }
+    }
+
     {
         params.sampling.logit_bias.clear();