]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
`server`: fix format of streamed tool call deltas (diff name, fix id location) (...
authorOlivier Chafik <redacted>
Mon, 26 May 2025 13:56:49 +0000 (06:56 -0700)
committerGitHub <redacted>
Mon, 26 May 2025 13:56:49 +0000 (14:56 +0100)
* fix deltas of tool_call.function.name

* fix tool_call.id (was in tool_call.function.id!) + add function type

* add tool_call.type

* populate empty tool_call.function.arguments on first delta

common/chat.cpp
tests/test-chat.cpp
tools/server/tests/utils.py

index c2379f669dc89b13470cf8514f46bab36aae1a1b..2e6a964bbeeceefbb568af185db8b8b432d4bb15 100644 (file)
@@ -106,9 +106,9 @@ std::vector<common_chat_msg_diff> common_chat_msg_diff::compute_diffs(const comm
         if (!args_diff.empty() || pref.id != newf.id) {
             auto & diff = diffs.emplace_back();
             diff.tool_call_index = idx;
-            diff.tool_call_delta.name = newf.name;
             if (pref.id != newf.id) {
                 diff.tool_call_delta.id = newf.id;
+                diff.tool_call_delta.name = newf.name;
             }
             diff.tool_call_delta.arguments = args_diff;
         }
@@ -392,22 +392,19 @@ template <> json common_chat_msg_diff_to_json_oaicompat(const common_chat_msg_di
         delta["content"] = diff.content_delta;
     }
     if (diff.tool_call_index != std::string::npos) {
+        json tool_call;
+        tool_call["index"] = diff.tool_call_index;
+        if (!diff.tool_call_delta.id.empty()) {
+            tool_call["id"] = diff.tool_call_delta.id;
+            tool_call["type"] = "function";
+        }
         json function = json::object();
         if (!diff.tool_call_delta.name.empty()) {
             function["name"] = diff.tool_call_delta.name;
         }
-        if (!diff.tool_call_delta.id.empty()) {
-            function["id"] = diff.tool_call_delta.id;
-        }
-        if (!diff.tool_call_delta.arguments.empty()) {
-            function["arguments"] = diff.tool_call_delta.arguments;
-        }
-        delta["tool_calls"] = json::array({
-            json {
-                {"index", diff.tool_call_index},
-                {"function", function}
-            }
-        });
+        function["arguments"] = diff.tool_call_delta.arguments;
+        tool_call["function"] = function;
+        delta["tool_calls"] = json::array({tool_call});
     }
     return delta;
 }
index fb048022a06c4b56de502c887b015f3d3ec4ec19..5f542f022417b56b078205605cb1433846a01f3d 100644 (file)
@@ -1356,8 +1356,7 @@ static void test_msg_diffs_compute() {
 
         common_chat_msg_diff diff12;
         diff12.tool_call_index = 0;
-        diff12.tool_call_delta.name = "special_function";
-        // Note: id doesnt change here.
+        // Note: neither id nor name change here.
         diff12.tool_call_delta.arguments = "g1\": 1}";
 
         assert_equals(
index 11672f515df1d7e9598f96b69cb7b29e9085ef4a..f7e1b3b3b7b8ea053b6c0f2c0dc0e24308547678 100644 (file)
@@ -328,6 +328,10 @@ class ServerProcess:
                     if 'function' not in tc:
                         raise ValueError(f"Expected function type, got {tc['type']}")
                     if tc['index'] >= len(tool_calls):
+                        assert 'id' in tc
+                        assert tc.get('type') == 'function'
+                        assert 'function' in tc and 'name' in tc['function'] and len(tc['function']['name']) > 0, \
+                            f"Expected function call with name, got {tc.get('function')}"
                         tool_calls.append(dict(
                             id="",
                             type="function",
@@ -340,10 +344,10 @@ class ServerProcess:
                     if tc.get('id') is not None:
                         tool_call['id'] = tc['id']
                     fct = tc['function']
+                    assert 'id' not in fct, f"Function call should not have id: {fct}"
                     if fct.get('name') is not None:
-                        tool_call['function']['name'] = fct['name']
+                        tool_call['function']['name'] = tool_call['function'].get('name', '') + fct['name']
                     if fct.get('arguments') is not None:
-                        assert len(fct['arguments']) > 0, f'Expected non empty arguments delta!'
                         tool_call['function']['arguments'] += fct['arguments']
 
             print(f'Streamed response had {content_parts} content parts, {tool_call_parts} tool call parts incl. {arguments_parts} arguments parts')