]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
common: map developer role to system (#20215)
authorPiotr Wilkin (ilintar) <redacted>
Mon, 9 Mar 2026 13:25:11 +0000 (14:25 +0100)
committerGitHub <redacted>
Mon, 9 Mar 2026 13:25:11 +0000 (14:25 +0100)
* Map developer role to system
* Simplify

common/chat.cpp
tests/test-chat.cpp

index 29d2e5fd12d5b673ce02e4455be3d968821841f2..d86bad462b16d849f9af653df8d01f7894c6bed7 100644 (file)
@@ -1352,6 +1352,17 @@ static common_chat_params common_chat_params_init_lfm2(const common_chat_templat
 
 namespace workaround {
 
+static void map_developer_role_to_system(json & messages) {
+    for (auto & message : messages) {
+        if (message.contains("role")) {
+            if (message["role"] == "developer") {
+                message["role"] = "system";
+            }
+        }
+    }
+}
+
+
 // if first message is system and template does not support it, merge it with next message
 static void system_message_not_supported(json & messages) {
     if (!messages.empty() && messages.front().at("role") == "system") {
@@ -1429,6 +1440,10 @@ static common_chat_params common_chat_templates_apply_jinja(const struct common_
     params.add_bos = tmpls->add_bos;
     params.add_eos = tmpls->add_eos;
 
+    if (src.find("<|channel|>") == std::string::npos) {
+        // map developer to system for all models except for GPT-OSS
+        workaround::map_developer_role_to_system(params.messages);
+    }
     workaround::func_args_not_string(params.messages);
 
     if (!tmpl.original_caps().supports_system_role) {
index 2f83d7c0b1e39e073fda5512ce35879650afaf46..b46a34e9398ecb955d1ceb3face2990d2ebeffde 100644 (file)
@@ -800,258 +800,6 @@ const common_chat_msg message_assist_call_python_lines_unclosed =
 const common_chat_msg message_assist_json_content =
     simple_assist_msg("{\n  \"response\": \"Hello, world!\\nWhat's up?\"\n}");
 
-struct delta_data {
-    std::string        delta;
-    common_chat_params params;
-};
-
-static delta_data init_delta(const struct common_chat_templates *  tmpls,
-                             const std::vector<std::string> &      end_tokens,
-                             const common_chat_msg &               user_message,
-                             const common_chat_msg &               delta_message,
-                             const std::vector<common_chat_tool> & tools,
-                             const common_chat_tool_choice &       tool_choice) {
-    common_chat_templates_inputs inputs;
-    inputs.parallel_tool_calls = true;
-    inputs.messages.push_back(user_message);
-    inputs.tools       = tools;
-    inputs.tool_choice = tool_choice;
-    auto params_prefix = common_chat_templates_apply(tmpls, inputs);
-
-    inputs.messages.push_back(delta_message);
-    inputs.add_generation_prompt = false;
-    auto params_full             = common_chat_templates_apply(tmpls, inputs);
-
-    std::string prefix = params_prefix.prompt;
-    std::string full   = params_full.prompt;
-
-    if (full == prefix) {
-        throw std::runtime_error("Full message is the same as the prefix");
-    }
-
-    size_t common_prefix_length = 0;
-    for (size_t i = 0; i < prefix.size() && i < full.size(); ++i) {
-        if (prefix[i] != full[i]) {
-            break;
-        }
-        if (prefix[i] == '<') {
-            // DeepSeek R1's template (as of 20250209) adds a trailing <think> if add_generation_prompt,
-            // but it removes thinking tags for past messages.
-            // The prefix and full strings diverge at <think> vs. <|tool▁calls▁begin|>, we avoid consuming the leading <.
-            continue;
-        }
-        common_prefix_length = i + 1;
-    }
-    auto delta = full.substr(common_prefix_length);
-
-    // Strip end tokens
-    for (const auto & end_token : end_tokens) {
-        // rfind to find the last occurrence
-        auto pos = delta.rfind(end_token);
-        if (pos != std::string::npos) {
-            delta = delta.substr(0, pos);
-            break;
-        }
-    }
-    return { delta, params_full };
-}
-
-/*
-  Applies the template to 1 user message w/ add_generation_prompt=true, then w/ the test message w/ add_generation_prompt=false,
-  gets the diff, removes any end tokens and parses the result w/ the grammar, checking that
-  the parsed message is the same as the test_message
-*/
-static void test_templates(const struct common_chat_templates *  tmpls,
-                           const std::vector<std::string> &      end_tokens,
-                           const common_chat_msg &               test_message,
-                           const std::vector<common_chat_tool> & tools                     = {},
-                           const std::string &                   expected_delta            = "",
-                           bool                                  expect_grammar_triggered  = true,
-                           bool                                  test_grammar_if_triggered = true,
-                           common_reasoning_format               reasoning_format = COMMON_REASONING_FORMAT_NONE,
-                           bool                                  ignore_whitespace_differences = false) {
-    common_chat_msg user_message;
-    user_message.role    = "user";
-    user_message.content = "Hello, world!";
-
-    common_chat_templates_inputs inputs_tools;
-    inputs_tools.messages = { message_user };
-    inputs_tools.tools    = { special_function_tool };
-
-    common_chat_params params = common_chat_templates_apply(tmpls, inputs_tools);
-
-    for (const auto & tool_choice :
-         std::vector<common_chat_tool_choice>{ COMMON_CHAT_TOOL_CHOICE_AUTO, COMMON_CHAT_TOOL_CHOICE_REQUIRED }) {
-        auto data = init_delta(tmpls, end_tokens, user_message, test_message, tools, tool_choice);
-        if (!expected_delta.empty()) {
-            if (ignore_whitespace_differences) {
-                assert_equals(string_strip(expected_delta), string_strip(data.delta));
-            } else {
-                assert_equals(expected_delta, data.delta);
-            }
-        }
-
-        if (expect_grammar_triggered) {
-            // TODO @ngxson : refactor common_chat_parse to avoid passing format/reasoning_format every time
-            common_chat_parser_params parser_params;
-            parser_params.format           = data.params.format;
-            parser_params.reasoning_format = reasoning_format;
-            if (!parser_params.parser.empty()) {
-                parser_params.parser = common_peg_arena();
-                parser_params.parser.load(params.parser);
-            }
-            const auto msg = common_chat_parse(data.delta, /* is_partial= */ false, parser_params);
-            assert_msg_equals(test_message, msg, ignore_whitespace_differences);
-        }
-
-        if (!test_message.tool_calls.empty()) {
-            GGML_ASSERT(!data.params.grammar.empty());
-        }
-        if (!data.params.grammar.empty()) {
-            auto grammar = build_grammar(data.params.grammar);
-            if (!grammar) {
-                throw std::runtime_error("Failed to build grammar");
-            }
-            auto earliest_trigger_pos = std::string::npos;
-            auto constrained          = data.delta;
-            for (const auto & trigger : data.params.grammar_triggers) {
-                size_t      pos = std::string::npos;
-                std::smatch match;
-                switch (trigger.type) {
-                    case COMMON_GRAMMAR_TRIGGER_TYPE_WORD:
-                        {
-                            const auto & word = trigger.value;
-                            pos               = constrained.find(word);
-                            break;
-                        }
-                    case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN:
-                        {
-                            const auto & pattern = std::regex(trigger.value);
-                            if (std::regex_search(constrained, match, pattern)) {
-                                pos = match.position(pattern.mark_count());
-                            }
-                            break;
-                        }
-                    case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL:
-                        {
-                            const auto & pattern = trigger.value;
-                            if (std::regex_match(constrained, match, std::regex(pattern))) {
-                                auto mpos = std::string::npos;
-                                for (size_t i = 1; i < match.size(); ++i) {
-                                    if (match[i].length() > 0) {
-                                        mpos = match.position(i);
-                                        break;
-                                    }
-                                }
-                                if (mpos == std::string::npos) {
-                                    mpos = match.position(0);
-                                }
-                                pos = mpos;
-                            }
-                            break;
-                        }
-                    default:
-                        throw std::runtime_error("Unknown trigger type");
-                }
-                if (pos == std::string::npos) {
-                    continue;
-                }
-                if (earliest_trigger_pos == std::string::npos || pos < earliest_trigger_pos) {
-                    earliest_trigger_pos = pos;
-                }
-            }
-            auto grammar_triggered = false;
-            if (earliest_trigger_pos != std::string::npos) {
-                constrained       = constrained.substr(earliest_trigger_pos);
-                grammar_triggered = true;
-            }
-            if (data.params.grammar_lazy) {
-                assert_equals(expect_grammar_triggered, grammar_triggered);
-            }
-
-            if (grammar_triggered && test_grammar_if_triggered && !match_string(constrained, grammar.get())) {
-                throw std::runtime_error("Failed to match delta against grammar:\n\n" + data.delta +
-                                         "\n\nConstrained: " + constrained + "\n\nGrammar: " + data.params.grammar);
-            }
-        }
-    }
-}
-
-/**
- * Test if streaming=true is consistent with streaming=false for given partial parser
- * Also test if there is any problem with partial message
- */
-template <typename T>
-static void test_parser_with_streaming(const common_chat_msg & expected, const std::string & raw_message, T parse_msg) {
-    constexpr auto utf8_truncate_safe_len = [](const std::string_view s) -> size_t {
-        auto len = s.size();
-        if (len == 0) {
-            return 0;
-        }
-        auto i = len;
-        for (size_t back = 0; back < 4 && i > 0; ++back) {
-            --i;
-            unsigned char c = s[i];
-            if ((c & 0x80) == 0) {
-                return len;
-            }
-            if ((c & 0xC0) == 0xC0) {
-                size_t expected_len = 0;
-                if ((c & 0xE0) == 0xC0) {
-                    expected_len = 2;
-                } else if ((c & 0xF0) == 0xE0) {
-                    expected_len = 3;
-                } else if ((c & 0xF8) == 0xF0) {
-                    expected_len = 4;
-                } else {
-                    return i;
-                }
-                if (len - i >= expected_len) {
-                    return len;
-                }
-                return i;
-            }
-        }
-        return len - std::min(len, size_t(3));
-    };
-    constexpr auto utf8_truncate_safe_view = [utf8_truncate_safe_len](const std::string_view s) {
-        return s.substr(0, utf8_truncate_safe_len(s));
-    };
-
-    auto merged   = simple_assist_msg("");
-    auto last_msg = parse_msg("");
-    for (size_t i = 1; i <= raw_message.size(); ++i) {
-        auto curr_msg = parse_msg(std::string(utf8_truncate_safe_view(std::string_view(raw_message).substr(0, i))));
-        if (curr_msg == simple_assist_msg("")) {
-            continue;
-        }
-        LOG_INF("Streaming msg: %s\n", common_chat_msgs_to_json_oaicompat({ curr_msg }).dump().c_str());
-        for (auto diff : common_chat_msg_diff::compute_diffs(last_msg, curr_msg)) {
-            LOG_INF("Streaming diff: %s\n", common_chat_msg_diff_to_json_oaicompat(diff).dump().c_str());
-            if (!diff.reasoning_content_delta.empty()) {
-                merged.reasoning_content += diff.reasoning_content_delta;
-            }
-            if (!diff.content_delta.empty()) {
-                merged.content += diff.content_delta;
-            }
-            if (diff.tool_call_index != std::string::npos) {
-                if (!diff.tool_call_delta.name.empty()) {
-                    merged.tool_calls.push_back({ diff.tool_call_delta.name, "", "" });
-                }
-                if (!diff.tool_call_delta.arguments.empty()) {
-                    GGML_ASSERT(!merged.tool_calls.empty());
-                    merged.tool_calls.back().arguments += diff.tool_call_delta.arguments;
-                }
-            }
-            LOG_INF("Streaming merged: %s\n", common_chat_msgs_to_json_oaicompat({ merged }).dump().c_str());
-        }
-        assert_msg_equals(curr_msg, merged, true);
-        last_msg = curr_msg;
-    }
-    assert_msg_equals(expected, parse_msg(raw_message), true);
-    assert_msg_equals(expected, merged, true);
-}
-
 // Use for PEG parser implementations
 struct peg_test_case {
     common_chat_templates_inputs params;
@@ -3019,6 +2767,44 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
     }
 }
 
+// Test the developer role to system workaround with a simple mock template
+static void test_developer_role_to_system_workaround() {
+    LOG_DBG("%s\n", __func__);
+
+    // Simple mock template that supports system role
+    const std::string mock_template =
+        "{%- for message in messages -%}\n"
+        "  {{- '<|' + message.role + '|>' + message.content + '<|end|>' -}}\n"
+        "{%- endfor -%}\n"
+        "{%- if add_generation_prompt -%}\n"
+        "  {{- '<|assistant|>' -}}\n"
+        "{%- endif -%}";
+
+    auto tmpls = common_chat_templates_ptr(common_chat_templates_init(/* model= */ nullptr, mock_template));
+
+    // Test case 1: Developer message - should be changed to system
+    // After simplification we only test this case
+    {
+        common_chat_templates_inputs inputs;
+        common_chat_msg developer_msg;
+        developer_msg.role = "developer";
+        developer_msg.content = "You are a helpful developer assistant.";
+        inputs.messages = { developer_msg };
+        inputs.add_generation_prompt = false;
+
+        auto params = common_chat_templates_apply(tmpls.get(), inputs);
+
+        // The developer role should have been changed to system
+        if (params.prompt.find("<|developer|>") != std::string::npos) {
+            throw std::runtime_error("Test failed: developer role was not changed to system");
+        }
+        if (params.prompt.find("<|system|>You are a helpful developer assistant.<|end|>") == std::string::npos) {
+            throw std::runtime_error("Test failed: system message not found in output");
+        }
+        LOG_ERR("Test 1 passed: developer role changed to system\n");
+    }
+}
+
 static void test_msg_diffs_compute() {
     LOG_DBG("%s\n", __func__);
     {
@@ -3155,6 +2941,7 @@ int main(int argc, char ** argv) {
         test_msg_diffs_compute();
         test_msgs_oaicompat_json_conversion();
         test_tools_oaicompat_json_conversion();
+        test_developer_role_to_system_workaround();
         test_template_output_peg_parsers(detailed_debug);
         std::cout << "\n[chat] All tests passed!" << '\n';
     }