chat : support Magistral thinking (#16413)
author Pascal <redacted>
Fri, 3 Oct 2025 18:51:48 +0000 (20:51 +0200)
committer GitHub <redacted>
Fri, 3 Oct 2025 18:51:48 +0000 (21:51 +0300)
* feat: added a dedicated Magistral chat format that preserves [THINK] spans and parses reasoning before tool calls

* feat: new flow in the chat template test suite for Magistral
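Illustrative sketch (not part of the commit): how the new format parses a Magistral response, mirroring the test added in tests/test-chat.cpp below. It assumes the common_chat_parse / common_chat_syntax API from common/chat.h as shown in this diff; the common_chat_msg field names (content, reasoning_content) follow their use elsewhere in the test suite.

    #include "chat.h"   // common/chat.h: common_chat_parse, COMMON_CHAT_FORMAT_MAGISTRAL
    #include <string>

    int main() {
        // Raw Magistral output: reasoning wrapped in [THINK]...[/THINK],
        // followed by the user-visible answer.
        const std::string raw = "[THINK]raisonnement[/THINK]Réponse";

        common_chat_syntax syntax;
        syntax.format           = COMMON_CHAT_FORMAT_MAGISTRAL;
        syntax.reasoning_format = COMMON_REASONING_FORMAT_AUTO;

        // The Magistral parser first lifts the [THINK] span into
        // reasoning_content, then keeps the remainder as plain content.
        common_chat_msg msg = common_chat_parse(raw, /* is_partial= */ false, syntax);
        return (msg.reasoning_content == "raisonnement" && msg.content == "Réponse") ? 0 : 1;
    }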

common/chat.cpp
common/chat.h
tests/test-chat.cpp

diff --git a/common/chat.cpp b/common/chat.cpp
index 87212322ec2488f8b6c767cb8e241e8e06910125..afbb2a2bdd3c41233da3112127da473980d897d3 100644
@@ -625,6 +625,7 @@ const char * common_chat_format_name(common_chat_format format) {
         case COMMON_CHAT_FORMAT_CONTENT_ONLY: return "Content-only";
         case COMMON_CHAT_FORMAT_GENERIC: return "Generic";
         case COMMON_CHAT_FORMAT_MISTRAL_NEMO: return "Mistral Nemo";
+        case COMMON_CHAT_FORMAT_MAGISTRAL: return "Magistral";
         case COMMON_CHAT_FORMAT_LLAMA_3_X: return "Llama 3.x";
         case COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS: return "Llama 3.x with builtin tools";
         case COMMON_CHAT_FORMAT_DEEPSEEK_R1: return "DeepSeek R1";
@@ -984,6 +985,65 @@ static common_chat_params common_chat_params_init_mistral_nemo(const common_chat
     data.format = COMMON_CHAT_FORMAT_MISTRAL_NEMO;
     return data;
 }
+
+static common_chat_params common_chat_params_init_magistral(const common_chat_template & tmpl, const struct templates_params & inputs) {
+    common_chat_params data;
+    data.prompt = apply(tmpl, inputs);
+    data.format = COMMON_CHAT_FORMAT_MAGISTRAL;
+    data.preserved_tokens = {
+        "[THINK]",
+        "[/THINK]",
+    };
+
+    if (inputs.tools.is_array() && !inputs.tools.empty()) {
+        data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
+        data.grammar = build_grammar([&](const common_grammar_builder & builder) {
+            auto schemas = json::array();
+            foreach_function(inputs.tools, [&](const json & tool) {
+                const auto & function = tool.at("function");
+                schemas.push_back({
+                    {"type", "object"},
+                    {"properties", {
+                        {"name", {
+                            {"type", "string"},
+                            {"const", function.at("name")},
+                        }},
+                        {"arguments", function.at("parameters")},
+                        {"id", {
+                            {"type", "string"},
+                            {"pattern", "^[a-zA-Z0-9]{9}$"},
+                        }},
+                    }},
+                    {"required", json::array({"name", "arguments", "id"})},
+                });
+            });
+            auto schema = json {
+                {"type", "array"},
+                {"items", schemas.size() == 1 ? schemas[0] : json {{"anyOf", schemas}}},
+                {"minItems", 1},
+            };
+            if (!inputs.parallel_tool_calls) {
+                schema["maxItems"] = 1;
+            }
+            builder.add_rule("root", "\"[TOOL_CALLS]\" " + builder.add_schema("tool_calls", schema));
+        });
+        data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "[TOOL_CALLS]"});
+        data.preserved_tokens.push_back("[TOOL_CALLS]");
+    } else {
+        data.grammar_lazy = false;
+        if (!inputs.json_schema.is_null()) {
+            if (!inputs.grammar.empty()) {
+                throw std::runtime_error("Either \"json_schema\" or \"grammar\" can be specified, but not both");
+            }
+            data.grammar = json_schema_to_grammar(inputs.json_schema);
+        } else {
+            data.grammar = inputs.grammar;
+        }
+    }
+
+    return data;
+}
+
 static void common_chat_parse_mistral_nemo(common_chat_msg_parser & builder) {
     if (!builder.syntax().parse_tool_calls) {
         builder.add_content(builder.consume_rest());
@@ -994,6 +1054,18 @@ static void common_chat_parse_mistral_nemo(common_chat_msg_parser & builder) {
     parse_prefixed_json_tool_call_array(builder, prefix);
 }
 
+static void common_chat_parse_magistral(common_chat_msg_parser & builder) {
+    builder.try_parse_reasoning("[THINK]", "[/THINK]");
+
+    if (!builder.syntax().parse_tool_calls) {
+        builder.add_content(builder.consume_rest());
+        return;
+    }
+
+    static const common_regex prefix(regex_escape("[TOOL_CALLS]"));
+    parse_prefixed_json_tool_call_array(builder, prefix);
+}
+
 static common_chat_params common_chat_params_init_command_r7b(const common_chat_template & tmpl, const struct templates_params & inputs) {
     common_chat_params data;
 
@@ -2702,6 +2774,10 @@ static common_chat_params common_chat_templates_apply_jinja(
         return common_chat_params_init_llama_3_x(tmpl, params, allow_python_tag_builtin_tools);
     }
 
+    if (src.find("[THINK]") != std::string::npos && src.find("[/THINK]") != std::string::npos) {
+        return common_chat_params_init_magistral(tmpl, params);
+    }
+
     // Plain handler (no tools)
     if (params.tools.is_null() || inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_NONE) {
         return common_chat_params_init_without_tools(tmpl, params);
@@ -2802,6 +2878,9 @@ static void common_chat_parse(common_chat_msg_parser & builder) {
         case COMMON_CHAT_FORMAT_MISTRAL_NEMO:
             common_chat_parse_mistral_nemo(builder);
             break;
+        case COMMON_CHAT_FORMAT_MAGISTRAL:
+            common_chat_parse_magistral(builder);
+            break;
         case COMMON_CHAT_FORMAT_LLAMA_3_X:
             common_chat_parse_llama_3_1(builder);
             break;
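For orientation (not part of the commit): when tools are supplied, common_chat_params_init_magistral() reuses the Mistral Nemo convention. The grammar is lazy unless a tool call is required, fires on the [TOOL_CALLS] trigger word, and from that point on constrains output to a JSON array of {name, arguments, id} objects, where id must match ^[a-zA-Z0-9]{9}$. Because the grammar is lazy, anything before the trigger, such as a [THINK] span, is left unconstrained. A completion could therefore look like the following (get_weather is a hypothetical tool, not one from this diff):

    [THINK]The user wants the weather, so I should call the tool.[/THINK]
    [TOOL_CALLS][{"name": "get_weather", "arguments": {"city": "Paris"}, "id": "abc123XYZ"}]

Since common_chat_parse_magistral() calls try_parse_reasoning() before scanning for the [TOOL_CALLS] prefix, the [THINK] span is lifted into reasoning_content and the array is parsed as tool calls.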
diff --git a/common/chat.h b/common/chat.h
index 3c277e15eba7fb8adfa06f2094773fb2a62f37bf..a1afe574bd0cade3bbe085ca2740752455fd9c2c 100644
@@ -101,6 +101,7 @@ enum common_chat_format {
     COMMON_CHAT_FORMAT_CONTENT_ONLY,
     COMMON_CHAT_FORMAT_GENERIC,
     COMMON_CHAT_FORMAT_MISTRAL_NEMO,
+    COMMON_CHAT_FORMAT_MAGISTRAL,
     COMMON_CHAT_FORMAT_LLAMA_3_X,
     COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS,
     COMMON_CHAT_FORMAT_DEEPSEEK_R1,
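A sketch combining both paths (not from the commit): common_chat_parse_magistral() runs try_parse_reasoning() before looking for the [TOOL_CALLS] prefix, so one completion can carry a [THINK] span and a tool call. This assumes the common_chat_msg tool_calls fields (name, arguments, id) and the parse_tool_calls flag on common_chat_syntax match their use in the test file below.

    #include "chat.h"
    #include <string>

    int main() {
        const std::string raw =
            "[THINK]plan the call[/THINK]"
            "[TOOL_CALLS][{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}, \"id\": \"123456789\"}]";

        common_chat_syntax syntax;
        syntax.format           = COMMON_CHAT_FORMAT_MAGISTRAL;
        syntax.reasoning_format = COMMON_REASONING_FORMAT_AUTO;
        syntax.parse_tool_calls = true; // assumed field name, mirroring /* .parse_tool_calls = */ in other tests

        common_chat_msg msg = common_chat_parse(raw, /* is_partial= */ false, syntax);
        // Expected split: reasoning_content == "plan the call", plus one
        // tool call named "special_function" with id "123456789".
        return (msg.reasoning_content == "plan the call" &&
                msg.tool_calls.size() == 1 &&
                msg.tool_calls[0].name == "special_function") ? 0 : 1;
    }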
diff --git a/tests/test-chat.cpp b/tests/test-chat.cpp
index 9cd67e3ef49d375c380c440134426214a61a07c7..52e23b5ac61f5f5df1617985b34a29b3cf02568d 100644
@@ -411,6 +411,7 @@ const common_chat_msg message_assist_thoughts_unparsed_md         = simple_assis
 const common_chat_msg message_assist_thoughts_unparsed_md_partial = simple_assist_msg("<think>I'm\nthinking</think>Hello, world!\nWhat's up?\n```json\n{}");
 
 const common_chat_msg message_assist_thoughts_unparsed_r7b       = simple_assist_msg("<|START_THINKING|>I'm\nthinking<|END_THINKING|>Hello, world!\nWhat's up?");
+const common_chat_msg message_assist_thoughts_unparsed_magistral = simple_assist_msg("[THINK]raisonnement[/THINK]Réponse");
 const common_chat_msg message_assist_thoughts                    = simple_assist_msg("Hello, world!\nWhat's up?", "I'm\nthinking");
 const common_chat_msg message_assist_thoughts_unopened_unparsed  = simple_assist_msg("I'm\nthinking</think>Hello, world!\nWhat's up?");
 const common_chat_msg message_assist_thoughts_no_content         = simple_assist_msg("", "I'm\nthinking");
@@ -745,6 +746,17 @@ static void test_template_output_parsers() {
             tmpls.get(), end_tokens, message_assist_call_id, tools,
             "[TOOL_CALLS][{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}, \"id\": \"123456789\"}]");
     }
+    {
+        assert_msg_equals(
+            simple_assist_msg("Réponse", "raisonnement"),
+            common_chat_parse(
+                message_assist_thoughts_unparsed_magistral.content,
+                /* is_partial= */ false,
+                {
+                    /* .format = */ COMMON_CHAT_FORMAT_MAGISTRAL,
+                    /* .reasoning_format = */ COMMON_REASONING_FORMAT_AUTO,
+                }));
+    }
     {
         auto tmpls = read_templates("models/templates/Qwen-QwQ-32B.jinja");
         std::vector<std::string> end_tokens{ "<|im_end|>" };