diff --git a/pkg/providers/http_provider.go b/pkg/providers/http_provider.go index 72d5fce..e0dcbbf 100644 --- a/pkg/providers/http_provider.go +++ b/pkg/providers/http_provider.go @@ -144,7 +144,45 @@ func toResponsesInputItems(msg Message) []map[string]interface{} { case "system", "developer", "user": return []map[string]interface{}{responsesMessageItem(role, msg.Content, "input_text")} case "assistant": - return []map[string]interface{}{responsesMessageItem(role, msg.Content, "output_text")} + items := make([]map[string]interface{}, 0, 1+len(msg.ToolCalls)) + if strings.TrimSpace(msg.Content) != "" || len(msg.ToolCalls) == 0 { + items = append(items, responsesMessageItem(role, msg.Content, "output_text")) + } + for _, tc := range msg.ToolCalls { + callID := strings.TrimSpace(tc.ID) + if callID == "" { + continue + } + name := strings.TrimSpace(tc.Name) + argsRaw := "" + if tc.Function != nil { + if strings.TrimSpace(tc.Function.Name) != "" { + name = strings.TrimSpace(tc.Function.Name) + } + argsRaw = strings.TrimSpace(tc.Function.Arguments) + } + if name == "" { + continue + } + if argsRaw == "" { + argsJSON, err := json.Marshal(tc.Arguments) + if err != nil { + argsRaw = "{}" + } else { + argsRaw = string(argsJSON) + } + } + items = append(items, map[string]interface{}{ + "type": "function_call", + "call_id": callID, + "name": name, + "arguments": argsRaw, + }) + } + if len(items) == 0 { + return []map[string]interface{}{responsesMessageItem(role, msg.Content, "output_text")} + } + return items case "tool": if strings.TrimSpace(msg.ToolCallID) == "" { return []map[string]interface{}{responsesMessageItem("user", msg.Content, "input_text")} diff --git a/pkg/providers/provider_test.go b/pkg/providers/provider_test.go index 70bc61b..6a73f49 100644 --- a/pkg/providers/provider_test.go +++ b/pkg/providers/provider_test.go @@ -135,6 +135,41 @@ func TestToResponsesInputItems_AssistantUsesOutputText(t *testing.T) { } } +func TestToResponsesInputItems_AssistantPreservesToolCalls(t *testing.T) { + items := toResponsesInputItems(Message{ + Role: "assistant", + Content: "", + ToolCalls: []ToolCall{ + { + ID: "call_abc", + Name: "exec_command", + Arguments: map[string]interface{}{ + "cmd": "pwd", + }, + }, + }, + }) + if len(items) != 1 { + t.Fatalf("expected 1 item, got %d", len(items)) + } + gotType, _ := items[0]["type"].(string) + if gotType != "function_call" { + t.Fatalf("item type = %q, want function_call", gotType) + } + gotCallID, _ := items[0]["call_id"].(string) + if gotCallID != "call_abc" { + t.Fatalf("call_id = %q, want call_abc", gotCallID) + } + gotName, _ := items[0]["name"].(string) + if gotName != "exec_command" { + t.Fatalf("name = %q, want exec_command", gotName) + } + gotArgs, _ := items[0]["arguments"].(string) + if !strings.Contains(gotArgs, "\"cmd\":\"pwd\"") { + t.Fatalf("arguments = %q, want serialized cmd", gotArgs) + } +} + func containsFunctionCallMarkup(s string) bool { return len(s) > 0 && (strings.Contains(s, "") || strings.Contains(s, "")) }