const common_chat_msg message_assist_json_content =
simple_assist_msg("{\n \"response\": \"Hello, world!\\nWhat's up?\"\n}");
-struct delta_data {
- std::string delta;
- common_chat_params params;
-};
-
-static delta_data init_delta(const struct common_chat_templates * tmpls,
- const std::vector<std::string> & end_tokens,
- const common_chat_msg & user_message,
- const common_chat_msg & delta_message,
- const std::vector<common_chat_tool> & tools,
- const common_chat_tool_choice & tool_choice) {
- common_chat_templates_inputs inputs;
- inputs.parallel_tool_calls = true;
- inputs.messages.push_back(user_message);
- inputs.tools = tools;
- inputs.tool_choice = tool_choice;
- auto params_prefix = common_chat_templates_apply(tmpls, inputs);
-
- inputs.messages.push_back(delta_message);
- inputs.add_generation_prompt = false;
- auto params_full = common_chat_templates_apply(tmpls, inputs);
-
- std::string prefix = params_prefix.prompt;
- std::string full = params_full.prompt;
-
- if (full == prefix) {
- throw std::runtime_error("Full message is the same as the prefix");
- }
-
- size_t common_prefix_length = 0;
- for (size_t i = 0; i < prefix.size() && i < full.size(); ++i) {
- if (prefix[i] != full[i]) {
- break;
- }
- if (prefix[i] == '<') {
- // DeepSeek R1's template (as of 20250209) adds a trailing <think> if add_generation_prompt,
- // but it removes thinking tags for past messages.
- // The prefix and full strings diverge at <think> vs. <|tool▁calls▁begin|>, we avoid consuming the leading <.
- continue;
- }
- common_prefix_length = i + 1;
- }
- auto delta = full.substr(common_prefix_length);
-
- // Strip end tokens
- for (const auto & end_token : end_tokens) {
- // rfind to find the last occurrence
- auto pos = delta.rfind(end_token);
- if (pos != std::string::npos) {
- delta = delta.substr(0, pos);
- break;
- }
- }
- return { delta, params_full };
-}
-
-/*
- Applies the template to 1 user message w/ add_generation_prompt=true, then w/ the test message w/ add_generation_prompt=false,
- gets the diff, removes any end tokens and parses the result w/ the grammar, checking that
- the parsed message is the same as the test_message
-*/
-static void test_templates(const struct common_chat_templates * tmpls,
- const std::vector<std::string> & end_tokens,
- const common_chat_msg & test_message,
- const std::vector<common_chat_tool> & tools = {},
- const std::string & expected_delta = "",
- bool expect_grammar_triggered = true,
- bool test_grammar_if_triggered = true,
- common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_NONE,
- bool ignore_whitespace_differences = false) {
- common_chat_msg user_message;
- user_message.role = "user";
- user_message.content = "Hello, world!";
-
- common_chat_templates_inputs inputs_tools;
- inputs_tools.messages = { message_user };
- inputs_tools.tools = { special_function_tool };
-
- common_chat_params params = common_chat_templates_apply(tmpls, inputs_tools);
-
- for (const auto & tool_choice :
- std::vector<common_chat_tool_choice>{ COMMON_CHAT_TOOL_CHOICE_AUTO, COMMON_CHAT_TOOL_CHOICE_REQUIRED }) {
- auto data = init_delta(tmpls, end_tokens, user_message, test_message, tools, tool_choice);
- if (!expected_delta.empty()) {
- if (ignore_whitespace_differences) {
- assert_equals(string_strip(expected_delta), string_strip(data.delta));
- } else {
- assert_equals(expected_delta, data.delta);
- }
- }
-
- if (expect_grammar_triggered) {
- // TODO @ngxson : refactor common_chat_parse to avoid passing format/reasoning_format every time
- common_chat_parser_params parser_params;
- parser_params.format = data.params.format;
- parser_params.reasoning_format = reasoning_format;
- if (!parser_params.parser.empty()) {
- parser_params.parser = common_peg_arena();
- parser_params.parser.load(params.parser);
- }
- const auto msg = common_chat_parse(data.delta, /* is_partial= */ false, parser_params);
- assert_msg_equals(test_message, msg, ignore_whitespace_differences);
- }
-
- if (!test_message.tool_calls.empty()) {
- GGML_ASSERT(!data.params.grammar.empty());
- }
- if (!data.params.grammar.empty()) {
- auto grammar = build_grammar(data.params.grammar);
- if (!grammar) {
- throw std::runtime_error("Failed to build grammar");
- }
- auto earliest_trigger_pos = std::string::npos;
- auto constrained = data.delta;
- for (const auto & trigger : data.params.grammar_triggers) {
- size_t pos = std::string::npos;
- std::smatch match;
- switch (trigger.type) {
- case COMMON_GRAMMAR_TRIGGER_TYPE_WORD:
- {
- const auto & word = trigger.value;
- pos = constrained.find(word);
- break;
- }
- case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN:
- {
- const auto & pattern = std::regex(trigger.value);
- if (std::regex_search(constrained, match, pattern)) {
- pos = match.position(pattern.mark_count());
- }
- break;
- }
- case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL:
- {
- const auto & pattern = trigger.value;
- if (std::regex_match(constrained, match, std::regex(pattern))) {
- auto mpos = std::string::npos;
- for (size_t i = 1; i < match.size(); ++i) {
- if (match[i].length() > 0) {
- mpos = match.position(i);
- break;
- }
- }
- if (mpos == std::string::npos) {
- mpos = match.position(0);
- }
- pos = mpos;
- }
- break;
- }
- default:
- throw std::runtime_error("Unknown trigger type");
- }
- if (pos == std::string::npos) {
- continue;
- }
- if (earliest_trigger_pos == std::string::npos || pos < earliest_trigger_pos) {
- earliest_trigger_pos = pos;
- }
- }
- auto grammar_triggered = false;
- if (earliest_trigger_pos != std::string::npos) {
- constrained = constrained.substr(earliest_trigger_pos);
- grammar_triggered = true;
- }
- if (data.params.grammar_lazy) {
- assert_equals(expect_grammar_triggered, grammar_triggered);
- }
-
- if (grammar_triggered && test_grammar_if_triggered && !match_string(constrained, grammar.get())) {
- throw std::runtime_error("Failed to match delta against grammar:\n\n" + data.delta +
- "\n\nConstrained: " + constrained + "\n\nGrammar: " + data.params.grammar);
- }
- }
- }
-}
-
-/**
- * Test if streaming=true is consistent with streaming=false for given partial parser
- * Also test if there is any problem with partial message
- */
-template <typename T>
-static void test_parser_with_streaming(const common_chat_msg & expected, const std::string & raw_message, T parse_msg) {
- constexpr auto utf8_truncate_safe_len = [](const std::string_view s) -> size_t {
- auto len = s.size();
- if (len == 0) {
- return 0;
- }
- auto i = len;
- for (size_t back = 0; back < 4 && i > 0; ++back) {
- --i;
- unsigned char c = s[i];
- if ((c & 0x80) == 0) {
- return len;
- }
- if ((c & 0xC0) == 0xC0) {
- size_t expected_len = 0;
- if ((c & 0xE0) == 0xC0) {
- expected_len = 2;
- } else if ((c & 0xF0) == 0xE0) {
- expected_len = 3;
- } else if ((c & 0xF8) == 0xF0) {
- expected_len = 4;
- } else {
- return i;
- }
- if (len - i >= expected_len) {
- return len;
- }
- return i;
- }
- }
- return len - std::min(len, size_t(3));
- };
- constexpr auto utf8_truncate_safe_view = [utf8_truncate_safe_len](const std::string_view s) {
- return s.substr(0, utf8_truncate_safe_len(s));
- };
-
- auto merged = simple_assist_msg("");
- auto last_msg = parse_msg("");
- for (size_t i = 1; i <= raw_message.size(); ++i) {
- auto curr_msg = parse_msg(std::string(utf8_truncate_safe_view(std::string_view(raw_message).substr(0, i))));
- if (curr_msg == simple_assist_msg("")) {
- continue;
- }
- LOG_INF("Streaming msg: %s\n", common_chat_msgs_to_json_oaicompat({ curr_msg }).dump().c_str());
- for (auto diff : common_chat_msg_diff::compute_diffs(last_msg, curr_msg)) {
- LOG_INF("Streaming diff: %s\n", common_chat_msg_diff_to_json_oaicompat(diff).dump().c_str());
- if (!diff.reasoning_content_delta.empty()) {
- merged.reasoning_content += diff.reasoning_content_delta;
- }
- if (!diff.content_delta.empty()) {
- merged.content += diff.content_delta;
- }
- if (diff.tool_call_index != std::string::npos) {
- if (!diff.tool_call_delta.name.empty()) {
- merged.tool_calls.push_back({ diff.tool_call_delta.name, "", "" });
- }
- if (!diff.tool_call_delta.arguments.empty()) {
- GGML_ASSERT(!merged.tool_calls.empty());
- merged.tool_calls.back().arguments += diff.tool_call_delta.arguments;
- }
- }
- LOG_INF("Streaming merged: %s\n", common_chat_msgs_to_json_oaicompat({ merged }).dump().c_str());
- }
- assert_msg_equals(curr_msg, merged, true);
- last_msg = curr_msg;
- }
- assert_msg_equals(expected, parse_msg(raw_message), true);
- assert_msg_equals(expected, merged, true);
-}
-
// Use for PEG parser implementations
struct peg_test_case {
common_chat_templates_inputs params;
}
}
+// Test the developer role to system workaround with a simple mock template.
+//
+// Renders a single "developer" message through a minimal template that prints
+// every message as `<|role|>content<|end|>`, then checks the produced prompt:
+//   - it must not contain the `<|developer|>` marker, and
+//   - it must contain the same content under the `<|system|>` marker.
+// Throws std::runtime_error if either check fails.
+static void test_developer_role_to_system_workaround() {
+    LOG_DBG("%s\n", __func__);
+
+    // Simple mock template that supports system role
+    const std::string mock_template =
+        "{%- for message in messages -%}\n"
+        " {{- '<|' + message.role + '|>' + message.content + '<|end|>' -}}\n"
+        "{%- endfor -%}\n"
+        "{%- if add_generation_prompt -%}\n"
+        " {{- '<|assistant|>' -}}\n"
+        "{%- endif -%}";
+
+    const auto tmpls = common_chat_templates_ptr(common_chat_templates_init(/* model= */ nullptr, mock_template));
+
+    // Test case 1: Developer message - should be changed to system
+    // After simplification we only test this case
+    {
+        common_chat_msg developer_msg;
+        developer_msg.role    = "developer";
+        developer_msg.content = "You are a helpful developer assistant.";
+
+        common_chat_templates_inputs inputs;
+        inputs.messages = { developer_msg };
+        // No generation prompt: the rendered prompt should hold only the message itself.
+        inputs.add_generation_prompt = false;
+
+        const auto params = common_chat_templates_apply(tmpls.get(), inputs);
+
+        // The developer role should have been changed to system
+        if (params.prompt.find("<|developer|>") != std::string::npos) {
+            throw std::runtime_error("Test failed: developer role was not changed to system");
+        }
+        if (params.prompt.find("<|system|>You are a helpful developer assistant.<|end|>") == std::string::npos) {
+            throw std::runtime_error("Test failed: system message not found in output");
+        }
+        // Success is informational; failures are reported through the exceptions above.
+        LOG_INF("Test 1 passed: developer role changed to system\n");
+    }
+}
+
static void test_msg_diffs_compute() {
LOG_DBG("%s\n", __func__);
{
test_msg_diffs_compute();
test_msgs_oaicompat_json_conversion();
test_tools_oaicompat_json_conversion();
+ test_developer_role_to_system_workaround();
test_template_output_peg_parsers(detailed_debug);
std::cout << "\n[chat] All tests passed!" << '\n';
}