common : inhibit lazy grammar sampler while reasoning is active (#20970)
author Aldehir Rojas <redacted>
Fri, 27 Mar 2026 17:30:40 +0000 (12:30 -0500)
committer GitHub <redacted>
Fri, 27 Mar 2026 17:30:40 +0000 (18:30 +0100)
* common : inhibit grammar while reasoning budget is active

* cont : update force_pos in accept

* cont : fix tests

* cont : tweak should apply logic

* cont : return early not using grammar sampler

* Add tests

* cont : prevent backend sampling when reasoning budget enabled

* cont : fix typo

---------

Co-authored-by: Piotr Wilkin <redacted>
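
At its core, the change gates lazy grammar constraints on the reasoning-budget sampler's state. A condensed sketch of the helper added in common/sampling.cpp (all names come from the hunks below; this is a simplified restatement, not a replacement for the patch):

static bool grammar_should_apply(struct common_sampler * gsmpl) {
    if (!gsmpl->grmr) {
        return false;                 // no grammar configured at all
    }
    if (!gsmpl->rbudget) {
        return true;                  // no reasoning-budget sampler -> behave as before
    }
    if (gsmpl->params.grammar_lazy) {
        // lazy grammar: only constrain when reasoning has not started or has finished
        const auto state = common_reasoning_budget_get_state(gsmpl->rbudget);
        return state == REASONING_BUDGET_IDLE || state == REASONING_BUDGET_DONE;
    }
    return true;                      // non-lazy grammars always apply
}

common_sampler_accept() evaluates this before feeding the token to the budget sampler, so the decision reflects the state the token was actually sampled under.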
common/reasoning-budget.cpp
common/reasoning-budget.h
common/sampling.cpp
tests/test-chat.cpp
tests/test-reasoning-budget.cpp
tools/cli/cli.cpp
tools/server/server-common.cpp
tools/server/server-task.cpp

index 2ef744278aeb0853268b91ef486f1f4168c49e2b..cc408a6869988ea40adb6484570e38d1917b5b9e 100644 (file)
@@ -115,9 +115,11 @@ static void common_reasoning_budget_accept(struct llama_sampler * smpl, llama_to
             break;
         }
         case REASONING_BUDGET_FORCING:
-            // force_pos is advanced in apply(), not here.
-            // This ensures the first forced token isn't skipped when the sampler
-            // is initialized directly in FORCING state (e.g. COUNTING + budget=0)
+            ctx->force_pos++;
+            if (ctx->force_pos >= ctx->forced_tokens.size()) {
+                ctx->state = REASONING_BUDGET_DONE;
+                LOG_INF("reasoning-budget: forced sequence complete, done\n");
+            }
             break;
         case REASONING_BUDGET_DONE:
             break;
@@ -144,14 +146,6 @@ static void common_reasoning_budget_apply(struct llama_sampler * smpl, llama_tok
             cur_p->data[i].logit = -INFINITY;
         }
     }
-
-    // advance to next forced token (done here rather than in accept so that
-    // the first forced token isn't skipped when starting in FORCING state)
-    ctx->force_pos++;
-    if (ctx->force_pos >= ctx->forced_tokens.size()) {
-        ctx->state = REASONING_BUDGET_DONE;
-        LOG_INF("reasoning-budget: forced sequence complete, done\n");
-    }
 }
 
 static void common_reasoning_budget_reset(struct llama_sampler * smpl) {
@@ -261,3 +255,10 @@ struct llama_sampler * common_reasoning_budget_init(
         common_reasoning_budget_state    initial_state) {
     return common_reasoning_budget_init_state(vocab, start_tokens, end_tokens, forced_tokens, budget, initial_state);
 }
+
+common_reasoning_budget_state common_reasoning_budget_get_state(const struct llama_sampler * smpl) {
+    if (!smpl) {
+        return REASONING_BUDGET_IDLE;
+    }
+    return ((const common_reasoning_budget_ctx *)smpl->ctx)->state;
+}
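
The new common_reasoning_budget_get_state() accessor is what sampling.cpp uses to make that decision; a null sampler reads as IDLE, so callers need no separate null check. A minimal call-site sketch (hypothetical helper; `rbudget` is whatever common_reasoning_budget_init() returned, and the states shown are the ones referenced in this patch):

static bool reasoning_active(const struct llama_sampler * rbudget) {
    switch (common_reasoning_budget_get_state(rbudget)) {
        case REASONING_BUDGET_COUNTING:   // counting tokens inside the thinking block
        case REASONING_BUDGET_FORCING:    // forcing the end-of-thinking token sequence
            return true;
        case REASONING_BUDGET_IDLE:       // thinking has not started (or no sampler at all)
        case REASONING_BUDGET_DONE:       // forced sequence complete
        default:
            return false;
    }
}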
index 130afdea4ac30a7d3767e8c04ed7b06d41b78b05..ee1a30ed3c1b5818e75d691dbd91623a796cf962 100644 (file)
@@ -51,3 +51,5 @@ struct llama_sampler * common_reasoning_budget_init(
         const std::vector<llama_token> & forced_tokens,
         int32_t                          budget,
         common_reasoning_budget_state    initial_state);
+
+common_reasoning_budget_state common_reasoning_budget_get_state(const struct llama_sampler * smpl);
index 012e2126604f09df649152d1f1047d85acdd672c..5259c5f3c6f58d49285869cafecd03112870763f 100644 (file)
@@ -7,6 +7,7 @@
 
 #include <algorithm>
 #include <cctype>
+#include <climits>
 #include <cmath>
 #include <cstring>
 #include <unordered_map>
@@ -109,6 +110,7 @@ struct common_sampler {
     common_params_sampling params;
 
     struct llama_sampler * grmr;
+    struct llama_sampler * rbudget;
     struct llama_sampler * chain;
 
     ring_buffer<llama_token> prev;
@@ -188,6 +190,7 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, st
     lparams.no_perf = params.no_perf;
 
     llama_sampler * grmr = nullptr;
+    llama_sampler * rbudget = nullptr;
     llama_sampler * chain = llama_sampler_chain_init(lparams);
 
     std::vector<llama_sampler *> samplers;
@@ -270,7 +273,7 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, st
             }
         }
 
-        if (grmr) {
+        if (grmr && !params.grammar_lazy) {
             try {
                 for (const auto & token : prefill_tokens) {
                     llama_sampler_accept(grmr, token);
@@ -284,15 +287,15 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, st
         }
     }
 
-    // reasoning budget sampler — added first so it can force tokens before other samplers
-    if (params.reasoning_budget_tokens >= 0 && !params.reasoning_budget_forced.empty()) {
-        samplers.push_back(common_reasoning_budget_init(
+    // reasoning budget sampler
+    if (!params.reasoning_budget_start.empty() && !params.reasoning_budget_end.empty()) {
+        rbudget = common_reasoning_budget_init(
             vocab,
             params.reasoning_budget_start,
             params.reasoning_budget_end,
             params.reasoning_budget_forced,
-            params.reasoning_budget_tokens,
-            prefill_tokens));
+            params.reasoning_budget_tokens < 0 ? INT_MAX : params.reasoning_budget_tokens,
+            prefill_tokens);
     }
 
     if (params.has_logit_bias()) {
@@ -383,6 +386,7 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, st
     auto * result = new common_sampler {
         /* .params  = */ params,
         /* .grmr    = */ grmr,
+        /* .rbudget = */ rbudget,
         /* .chain   = */ chain,
         /* .prev    = */ ring_buffer<llama_token>(std::max(32, params.n_prev)),
         /* .cur     = */ {},
@@ -398,11 +402,27 @@ void common_sampler_free(struct common_sampler * gsmpl) {
     }
 
     llama_sampler_free(gsmpl->grmr);
+    llama_sampler_free(gsmpl->rbudget);
     llama_sampler_free(gsmpl->chain);
 
     delete gsmpl;
 }
 
+static bool grammar_should_apply(struct common_sampler * gsmpl) {
+    if (!gsmpl->grmr) {
+        return false;
+    }
+    if (!gsmpl->rbudget) {
+        return true;
+    }
+    if (gsmpl->params.grammar_lazy) {
+        // if grammar is lazy, only apply when reasoning budget is not active
+        const auto state = common_reasoning_budget_get_state(gsmpl->rbudget);
+        return state == REASONING_BUDGET_IDLE || state == REASONING_BUDGET_DONE;
+    }
+    return true;
+}
+
 void common_sampler_accept(struct common_sampler * gsmpl, llama_token token, bool accept_grammar) {
     if (!gsmpl) {
         return;
@@ -410,6 +430,11 @@ void common_sampler_accept(struct common_sampler * gsmpl, llama_token token, boo
 
     const auto tm = gsmpl->tm();
 
+    // grammar_should_apply() checks the reasoning budget state, so calculate this before we accept
+    accept_grammar = accept_grammar && grammar_should_apply(gsmpl);
+
+    llama_sampler_accept(gsmpl->rbudget, token);
+
     if (gsmpl->grmr && accept_grammar) {
         llama_sampler_accept(gsmpl->grmr, token);
     }
@@ -431,6 +456,7 @@ struct common_sampler * common_sampler_clone(common_sampler * gsmpl) {
     return new common_sampler {
         /* .params  = */ gsmpl->params,
         /* .grmr    = */ llama_sampler_clone(gsmpl->grmr),
+        /* .rbudget = */ llama_sampler_clone(gsmpl->rbudget),
         /* .chain   = */ llama_sampler_clone(gsmpl->chain),
         /* .prev    = */ gsmpl->prev,
         /* .cur     = */ gsmpl->cur,
@@ -500,6 +526,7 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co
     llama_token id = LLAMA_TOKEN_NULL;
 
     auto & grmr  = gsmpl->grmr;
+    auto & rbudget = gsmpl->rbudget;
     auto & chain = gsmpl->chain;
     auto & cur_p = gsmpl->cur_p; // initialized by set_logits
 
@@ -511,7 +538,8 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co
         if (id != LLAMA_TOKEN_NULL) {
             LOG_DBG("%s: Backend sampler selected token: '%d'. Will not run any CPU samplers\n", __func__, id);
 
-            GGML_ASSERT(!gsmpl->grmr && "using grammar in combination with backend sampling is not supported");
+            GGML_ASSERT(!gsmpl->grmr    && "using grammar in combination with backend sampling is not supported");
+            GGML_ASSERT(!gsmpl->rbudget && "using reasoning budget in combination with backend sampling is not supported");
 
             // TODO: simplify
             gsmpl->cur.resize(1);
@@ -524,7 +552,10 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co
 
     gsmpl->set_logits(ctx, idx);
 
-    if (grammar_first) {
+    // apply reasoning budget first
+    llama_sampler_apply(rbudget, &cur_p);
+
+    if (grammar_first && grammar_should_apply(gsmpl)) {
         llama_sampler_apply(grmr, &cur_p);
     }
 
@@ -532,7 +563,7 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co
 
     id = cur_p.data[cur_p.selected].id;
 
-    if (grammar_first) {
+    if (grammar_first || !grammar_should_apply(gsmpl)) {
         return id;
     }
 
@@ -553,7 +584,12 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co
     // if the token is not valid, sample again, but first apply the grammar sampler and then the sampling chain
     gsmpl->set_logits(ctx, idx);
 
-    llama_sampler_apply(grmr,  &cur_p);
+    llama_sampler_apply(rbudget,  &cur_p);
+
+    if (grammar_should_apply(gsmpl)) {
+        llama_sampler_apply(grmr,  &cur_p);
+    }
+
     llama_sampler_apply(chain, &cur_p);
 
     GGML_ASSERT(cur_p.selected != -1 && "no selected token during sampling - check your sampling configuration");
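
Taken together, the sampling.cpp hunks make the per-token order in common_sampler_sample() roughly the following (condensed; the backend-sampling assert and the resample path shown above follow the same pattern):

llama_sampler_apply(rbudget, &cur_p);               // may pin the distribution to a forced token
if (grammar_first && grammar_should_apply(gsmpl)) {
    llama_sampler_apply(grmr, &cur_p);              // lazy grammar skipped while reasoning is active
}
llama_sampler_apply(chain, &cur_p);                 // regular sampler chain picks the token

When the lazy grammar is inhibited, the post-sample grammar validity check is also skipped (the early return on grammar_first || !grammar_should_apply), so forced reasoning tokens are never rejected by the grammar.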
index 575d240791747a6e8a676c271898140f7a9a3e06..74f078f5edda8301b1c02990272a6c039e1ea6ca 100644 (file)
@@ -936,75 +936,158 @@ static void test_peg_parser(common_chat_templates *                      tmpls,
             throw std::runtime_error("Failed to build grammar: " + parser.params_.grammar);
         }
 
+        // In production, grammar triggers match against the full generated text
+        // including the generation prompt. All positions are in full_input coordinates.
+        const auto & gen_prompt = parser.params_.generation_prompt;
+        std::string full_input = gen_prompt + tc.input;
+
+        // Determine whether the reasoning-budget sampler path applies: tool-call grammar
+        // with all WORD triggers and thinking tags present. In production, the reasoning
+        // budget sampler inhibits grammar application while inside thinking blocks —
+        // triggers inside <think>...</think> are suppressed.
+        bool use_reasoning_budget_path = false;
+        if (parser.params_.grammar_lazy && !parser.params_.thinking_end_tag.empty()) {
+            use_reasoning_budget_path = true;
+            for (const auto & trigger : parser.params_.grammar_triggers) {
+                if (trigger.type != COMMON_GRAMMAR_TRIGGER_TYPE_WORD) {
+                    use_reasoning_budget_path = false;
+                    break;
+                }
+            }
+        }
+
         // Find the earliest trigger position to determine the constrained portion
         auto earliest_trigger_pos = std::string::npos;
-        for (const auto & trigger : parser.params_.grammar_triggers) {
-            size_t      pos = std::string::npos;
-            std::smatch match;
-            switch (trigger.type) {
-                case COMMON_GRAMMAR_TRIGGER_TYPE_WORD:
-                    {
-                        const auto & word = trigger.value;
-                        pos               = tc.input.find(word);
-                        break;
-                    }
-                case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN:
-                    {
-                        const auto & pattern = std::regex(trigger.value);
-                        if (std::regex_search(tc.input, match, pattern)) {
-                            pos = match.position(pattern.mark_count());
+
+        if (use_reasoning_budget_path) {
+            // Reasoning-budget path: simulate thinking-aware trigger detection.
+            // Walk through full_input tracking thinking state; only match triggers
+            // when outside thinking blocks.
+            const auto & think_start = parser.params_.thinking_start_tag;
+            const auto & think_end   = parser.params_.thinking_end_tag;
+
+            bool in_thinking = false;
+            for (size_t i = 0; i < full_input.size(); ++i) {
+                if (!in_thinking && !think_start.empty()
+                        && full_input.compare(i, think_start.size(), think_start) == 0) {
+                    in_thinking = true;
+                    i += think_start.size() - 1;
+                    continue;
+                }
+                if (in_thinking && full_input.compare(i, think_end.size(), think_end) == 0) {
+                    in_thinking = false;
+                    i += think_end.size() - 1;
+                    continue;
+                }
+                if (in_thinking) {
+                    continue;
+                }
+                // Outside thinking — check if any trigger word starts here
+                for (const auto & trigger : parser.params_.grammar_triggers) {
+                    if (full_input.compare(i, trigger.value.size(), trigger.value) == 0) {
+                        if (earliest_trigger_pos == std::string::npos || i < earliest_trigger_pos) {
+                            earliest_trigger_pos = i;
                         }
-                        break;
                     }
-                case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL:
-                    {
-                        const auto & pattern = trigger.value;
-                        if (std::regex_match(tc.input, match, std::regex(pattern))) {
-                            auto mpos = std::string::npos;
-                            for (size_t i = 1; i < match.size(); ++i) {
-                                if (match[i].length() > 0) {
-                                    mpos = match.position(i);
+                }
+                if (earliest_trigger_pos != std::string::npos) {
+                    break;  // found the earliest
+                }
+            }
+
+            // If the reasoning-budget path found no trigger outside thinking but the test
+            // expects tool calls, this template nests tool calls inside thinking
+            // blocks (e.g. Kimi). Fall back to the legacy path for this case.
+            if (earliest_trigger_pos == std::string::npos && !tc.expect.tool_calls.empty()) {
+                use_reasoning_budget_path = false;
+            }
+        }
+
+        if (!use_reasoning_budget_path) {
+            // Legacy path: find triggers without thinking-awareness
+            for (const auto & trigger : parser.params_.grammar_triggers) {
+                size_t      pos = std::string::npos;
+                std::smatch match;
+                switch (trigger.type) {
+                    case COMMON_GRAMMAR_TRIGGER_TYPE_WORD:
+                        {
+                            const auto & word = trigger.value;
+                            pos               = full_input.find(word);
+                            break;
+                        }
+                    case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN:
+                        {
+                            const auto & compiled = std::regex(trigger.value);
+                            if (std::regex_search(full_input, match, compiled)) {
+                                pos = match.position(compiled.mark_count());
+                            }
+                            break;
+                        }
+                    case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL:
+                        {
+                            // In production, PATTERN_FULL triggers are checked against
+                            // the text generated so far, growing token by token. Simulate
+                            // by trying every prefix of full_input.
+                            const auto & compiled = std::regex(trigger.value);
+                            for (size_t end = gen_prompt.size(); end <= full_input.size(); ++end) {
+                                std::string prefix = full_input.substr(0, end);
+                                if (std::regex_match(prefix, match, compiled)) {
+                                    pos = std::string::npos;
+                                    for (size_t gi = 1; gi < match.size(); ++gi) {
+                                        if (match[gi].length() > 0) {
+                                            pos = match.position(gi);
+                                            break;
+                                        }
+                                    }
+                                    if (pos == std::string::npos) {
+                                        pos = match.position(0);
+                                    }
                                     break;
                                 }
                             }
-                            if (mpos == std::string::npos) {
-                                mpos = match.position(0);
-                            }
-                            pos = mpos;
+                            break;
                         }
-                        break;
+                    default:
+                        throw std::runtime_error("Unknown trigger type");
+                }
+                if (pos != std::string::npos) {
+                    if (earliest_trigger_pos == std::string::npos || pos < earliest_trigger_pos) {
+                        earliest_trigger_pos = pos;
                     }
-                default:
-                    throw std::runtime_error("Unknown trigger type");
-            }
-            if (pos != std::string::npos) {
-                if (earliest_trigger_pos == std::string::npos || pos < earliest_trigger_pos) {
-                    earliest_trigger_pos = pos;
                 }
             }
         }
 
-        // Determine the constrained portion of input to test against grammar
-        std::string constrained = tc.input;
+        // If the test expects tool calls and the grammar is lazy, the trigger must fire.
+        // Otherwise the grammar would never activate in production and tool calls wouldn't
+        // be constrained. A silent skip here would hide broken triggers.
+        if (parser.params_.grammar_lazy && !tc.expect.tool_calls.empty() && !tc.is_partial
+                && earliest_trigger_pos == std::string::npos) {
+            std::string trigger_desc;
+            for (const auto & trigger : parser.params_.grammar_triggers) {
+                trigger_desc += "\n  [type=" + std::to_string(trigger.type) + "] " + trigger.value;
+            }
+            throw std::runtime_error(
+                "Grammar trigger did not fire, but test expects tool calls (lazy grammar).\n"
+                ">>> Input: " + full_input + "\n"
+                ">>> Triggers (" + std::to_string(parser.params_.grammar_triggers.size()) + "):" + trigger_desc);
+        }
+
+        // Determine the constrained portion of input to test against grammar.
+        // If the trigger position falls inside the generation prompt, the grammar
+        // sampler was already active before model output began — constrain from the
+        // start of the model output (i.e. tc.input).
+        std::string constrained = full_input;
         bool grammar_triggered = false;
         if (earliest_trigger_pos != std::string::npos) {
-            constrained = tc.input.substr(earliest_trigger_pos);
+            auto constrain_from = std::max(earliest_trigger_pos, gen_prompt.size());
+            constrained = full_input.substr(constrain_from);
             grammar_triggered = true;
         } else if (!parser.params_.grammar_lazy) {
             // For non-lazy grammars, the entire input should match
             grammar_triggered = true;
         }
 
-        // For non-lazy grammars, prepend reasoning prefill to grammar input, just like
-        // PEG parsing does. The grammar includes the full reasoning pattern (e.g. optional
-        // <think>...</think>), but the model output may start mid-reasoning if the template
-        // already placed the opening tag in the prompt.
-        // For lazy grammars, the grammar only activates from the trigger position, so the
-        // reasoning prefill is irrelevant — reasoning is handled by the PEG parser.
-        if (!parser.params_.generation_prompt.empty() && earliest_trigger_pos == std::string::npos) {
-            constrained = parser.params_.generation_prompt + constrained;
-        }
-
         // Test the constrained portion against the grammar
         if (grammar_triggered && !tc.is_partial) {
             auto result = match_string_detailed(constrained, grammar.get());
@@ -1323,6 +1406,19 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
             .expect_reasoning("I need to output the invoice details in JSON")
             .expect_content(R"({"amount": 123.45, "date": "2025-12-03"})")
             .run();
+
+        // fake tool call marker in reasoning
+        tst.test(
+               "[THINK]Let me think about [TOOL_CALLS]special_function[ARGS]{\"arg1\":1} and more[/THINK]"
+               R"([TOOL_CALLS]special_function[ARGS]{"arg1": 1})")
+            .reasoning_format(COMMON_REASONING_FORMAT_AUTO)
+            .enable_thinking(true)
+            .tools({ special_function_tool })
+            .expect_reasoning("Let me think about [TOOL_CALLS]special_function[ARGS]{\"arg1\":1} and more")
+            .expect_tool_calls({
+                { "special_function", R"({"arg1": 1})", {} },
+            })
+            .run();
     }
 
     {
@@ -1425,6 +1521,50 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
             .expect_reasoning("I need to output the invoice details in JSON")
             .expect_content(R"({"amount": 123.45, "date": "2025-12-03"})")
             .run();
+
+        // tool call segment in reasoning
+        tst.test(
+               "Let's call a tool: <tool_call>\n"
+               "<function=python>\n"
+               "<parameter=code>\n"
+               "def hello():\n"
+               "    print(\"Not the real call!\")\n"
+               "\n"
+               "hello()\n"
+               "</parameter>\n"
+               "</function>\n"
+               "</tool_call></think>\n"
+               "<tool_call>\n"
+               "<function=python>\n"
+               "<parameter=code>\n"
+               "def hello():\n"
+               "    print(\"Hello, world!\")\n"
+               "\n"
+               "hello()\n"
+               "</parameter>\n"
+               "</function>\n"
+               "</tool_call>"
+            )
+            .enable_thinking(true)
+            .reasoning_format(COMMON_REASONING_FORMAT_AUTO)
+            .tools({
+                python_tool
+        })
+            .expect_reasoning("Let's call a tool: <tool_call>\n"
+               "<function=python>\n"
+               "<parameter=code>\n"
+               "def hello():\n"
+               "    print(\"Not the real call!\")\n"
+               "\n"
+               "hello()\n"
+               "</parameter>\n"
+               "</function>\n"
+               "</tool_call>")
+            .expect_tool_calls({
+                { "python", "{\"code\": \"def hello():\\n    print(\\\"Hello, world!\\\")\\n\\nhello()\"}", {} },
+            })
+            .run();
+
     }
 
     {
@@ -2297,6 +2437,19 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
             .tools({ empty_args_tool })
             .expect(simple_assist_msg("", "", "empty_args", "{}"))
             .run();
+
+        // fake tool call marker in reasoning
+        tst.test(
+               "<think>Let me think about <|tool_call_start|>[special_function(arg1=1)]<|tool_call_end|> hmm</think>"
+               "<|tool_call_start|>[special_function(arg1=1)]<|tool_call_end|>")
+            .enable_thinking(true)
+            .reasoning_format(COMMON_REASONING_FORMAT_AUTO)
+            .tools({ special_function_tool })
+            .expect_reasoning("Let me think about <|tool_call_start|>[special_function(arg1=1)]<|tool_call_end|> hmm")
+            .expect_tool_calls({
+                { "special_function", R"({"arg1": 1})", {} },
+            })
+            .run();
     }
 
     // Apertus-8B-Instruct tests - FUNC_NAME_AS_KEY format
index ab540a846302cfb7b41ff15105f03542a823148f..3028fb4d8f06ba9c60cd53c694027796f608e75a 100644 (file)
@@ -61,8 +61,6 @@ static void test_reasoning_budget(
 
     // Feed the sequence and track when forcing occurs
     for (size_t i = 0; i < sequence.size(); i++) {
-        llama_sampler_accept(sampler, sequence[i]);
-
         // Check if we're in forcing state by applying and seeing if logits are modified
         cur_p.selected = -1;
         for (size_t j = 0; j < cur.size(); j++) {
@@ -81,6 +79,8 @@ static void test_reasoning_budget(
             }
         }
 
+        llama_sampler_accept(sampler, sequence[i]);
+
         fprintf(stderr, "    i=%zu: token=%d, finite_count=%zu, finite_token=%d\n", i, (int)sequence[i], finite_count, (int)finite_token);
 
         if (finite_count == 1) {
@@ -167,9 +167,9 @@ int main(void) {
     }
 
     // Test 2: Budget exhausted, forcing should occur
-    // Flow: i=0 accept(100)->COUNTING, i=1 accept(50)->remaining=1, i=2 accept(51)->remaining=0->FORCING
-    // Forcing is active at i=2 and i=3 (when apply() is called while in FORCING state)
-    // At i=4, force_pos becomes 2 which equals forced_tokens.size(), so state becomes DONE
+    // Flow: i=0 apply()->passthrough, accept(100)->COUNTING; i=1 accept(50)->remaining=1
+    // i=2 accept(51)->remaining=0->FORCING; i=3 apply() forces token[0]; i=4 apply() forces token[1]
+    // At i=4, accept() advances force_pos to 2 which equals forced_tokens.size(), so state becomes DONE
     {
         const std::vector<llama_token> start = {100};
         const std::vector<llama_token> end = {101};
@@ -179,13 +179,12 @@ int main(void) {
         test_reasoning_budget("budget exhausted forcing", sequence, start, end, forced,
             2,      // budget of 2 tokens
             REASONING_BUDGET_IDLE,
-            2,      // forcing starts at i=2 (after accept(51) depletes budget, apply() forces)
-            3);     // forcing continues through i=3 (at i=4 state becomes DONE)
+            3,      // forcing starts at i=3 (accept at i=2 depletes budget, apply at i=3 forces)
+            4);     // forcing continues through i=4 (accept at i=4 transitions to DONE)
     }
 
     // Test 3: Activate immediately with budget=0, forcing should start right away
-    // Flow: Since no start token in sequence, state stays IDLE (no start/end configured means passthrough)
-    // This test needs start token to be in the sequence or use activate_immediately with start token present
+    // Flow: init promotes COUNTING+budget=0 to FORCING, so apply() sees FORCING at i=0
     {
         const std::vector<llama_token> start = {100};
         const std::vector<llama_token> end = {101};
@@ -195,8 +194,8 @@ int main(void) {
         test_reasoning_budget("activate immediately budget=0", sequence, start, end, forced,
             0,      // budget of 0 tokens
             REASONING_BUDGET_COUNTING, // starts counting, promoted to FORCING since budget=0
-            0,      // forcing starts at i=0 (after accept(100), budget=0 goes straight to FORCING)
-            1);     // forcing continues through i=1 (at i=2 state becomes DONE)
+            0,      // forcing starts at i=0 (initialized in FORCING, apply forces immediately)
+            1);     // forcing continues through i=1 (accept at i=1 transitions to DONE)
     }
 
     // Test 4: No start/end tokens configured - passthrough (no forcing)
@@ -214,7 +213,7 @@ int main(void) {
 
     // Test 5: Activate immediately with budget > 0, count down then force
     // Flow: i=0 accept(50)->remaining=1, i=1 accept(51)->remaining=0->FORCING
-    // So forcing starts at i=1 (apply after accept sees FORCING with force_pos=0)
+    // Forcing starts at i=2 (apply sees FORCING after accept at i=1 transitioned)
     {
         const std::vector<llama_token> start = {100};
         const std::vector<llama_token> end = {101};
@@ -224,8 +223,8 @@ int main(void) {
         test_reasoning_budget("activate immediately with budget", sequence, start, end, forced,
             2,      // budget of 2 tokens
             REASONING_BUDGET_COUNTING,
-            1,      // forcing starts at i=1 (after 2 accepts deplete budget)
-            2);     // forcing continues through i=2
+            2,      // forcing starts at i=2 (after 2 accepts deplete budget, apply at i=2 forces)
+            3);     // forcing continues through i=3
     }
 
     printf("OK (5 tests passed)\n");
index 65d14e97288fde7454ef03a2703f7ce2055ae5e3..f5b4426f6f6f16a2e27cc7dbcde2b39c26436895 100644 (file)
@@ -100,7 +100,7 @@ struct cli_context {
             }
 
             // reasoning budget sampler
-            if (reasoning_budget >= 0 && !chat_params.thinking_end_tag.empty()) {
+            if (!chat_params.thinking_end_tag.empty()) {
                 const llama_vocab * vocab = llama_model_get_vocab(
                     llama_get_model(ctx_server.get_llama_context()));
 
index e01c8c53df77b55bd1fc2ffae1a8f5ceee44b3b8..ed5e306fc5b2aa10b6ab3bba0616272cb8be48d5 100644 (file)
@@ -1110,7 +1110,7 @@ json oaicompat_chat_params_parse(
             reasoning_budget = json_value(body, "thinking_budget_tokens", -1);
         }
 
-        if (reasoning_budget >= 0 && !chat_params.thinking_end_tag.empty()) {
+        if (!chat_params.thinking_end_tag.empty()) {
             llama_params["reasoning_budget_tokens"] = reasoning_budget;
             llama_params["reasoning_budget_start_tag"] = chat_params.thinking_start_tag;
             llama_params["reasoning_budget_end_tag"] = chat_params.thinking_end_tag;
index 7d543b9292b8752811734016a5cfd3cde9d20992..3018ac90f8cf2fd408cef73e9638a74ee2c7572e 100644 (file)
@@ -478,19 +478,17 @@ task_params server_task::params_from_json_cmpl(
     // Parse reasoning budget sampler parameters
     {
         const int32_t budget = json_value(data, "reasoning_budget_tokens", (int32_t) -1);
-        if (budget >= 0) {
-            const auto start_tag = json_value(data, "reasoning_budget_start_tag", std::string());
-            const auto end_tag   = json_value(data, "reasoning_budget_end_tag", std::string());
-            const auto message   = json_value(data, "reasoning_budget_message", std::string());
-            params.sampling.reasoning_budget_tokens = budget;
-
-            if (!start_tag.empty()) {
-                params.sampling.reasoning_budget_start = common_tokenize(vocab, start_tag, false, true);
-            }
-            if (!end_tag.empty()) {
-                params.sampling.reasoning_budget_end = common_tokenize(vocab, end_tag, false, true);
-                params.sampling.reasoning_budget_forced = common_tokenize(vocab, message + end_tag, false, true);
-            }
+        const auto start_tag = json_value(data, "reasoning_budget_start_tag", std::string());
+        const auto end_tag   = json_value(data, "reasoning_budget_end_tag", std::string());
+        const auto message   = json_value(data, "reasoning_budget_message", std::string());
+        params.sampling.reasoning_budget_tokens = budget;
+
+        if (!start_tag.empty()) {
+            params.sampling.reasoning_budget_start = common_tokenize(vocab, start_tag, false, true);
+        }
+        if (!end_tag.empty()) {
+            params.sampling.reasoning_budget_end = common_tokenize(vocab, end_tag, false, true);
+            params.sampling.reasoning_budget_forced = common_tokenize(vocab, message + end_tag, false, true);
 
             SRV_DBG("reasoning budget: tokens=%d, generation_prompt='%s', start=%zu toks, end=%zu toks, forced=%zu toks\n",
                 budget, params.sampling.generation_prompt.c_str(),