From b8cf8ad1b11781aefcea03d1a29f03bc508eb11e Mon Sep 17 00:00:00 2001 From: lpf Date: Mon, 11 May 2026 18:40:45 +0800 Subject: [PATCH] fix(provider): preserve reasoning content in chat completions --- pkg/agent/loop.go | 5 +- pkg/providers/openai_compat_adapter.go | 104 ++++++++++++++++++- pkg/providers/openai_compat_provider_test.go | 97 +++++++++++++++++ pkg/providers/types.go | 20 ++-- pkg/session/manager.go | 21 ++-- 5 files changed, 223 insertions(+), 24 deletions(-) diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go index 7c80d53..96f4885 100644 --- a/pkg/agent/loop.go +++ b/pkg/agent/loop.go @@ -944,8 +944,9 @@ func estimateResponseUsage(ctx context.Context, provider providers.LLMProvider, func buildAssistantToolCallMessage(response *providers.LLMResponse) providers.Message { assistantMsg := providers.Message{ - Role: "assistant", - Content: response.Content, + Role: "assistant", + Content: response.Content, + ReasoningContent: response.ReasoningContent, } if response == nil { return assistantMsg diff --git a/pkg/providers/openai_compat_adapter.go b/pkg/providers/openai_compat_adapter.go index cdb7b63..57ec3a9 100644 --- a/pkg/providers/openai_compat_adapter.go +++ b/pkg/providers/openai_compat_adapter.go @@ -11,8 +11,9 @@ func parseOpenAICompatResponse(body []byte) (*LLMResponse, error) { var payload struct { Choices []struct { Message struct { - Content string `json:"content"` - ToolCalls []struct { + Content string `json:"content"` + ReasoningContent string `json:"reasoning_content"` + ToolCalls []struct { ID string `json:"id"` Type string `json:"type"` Function struct { @@ -37,8 +38,9 @@ func parseOpenAICompatResponse(body []byte) (*LLMResponse, error) { } choice := payload.Choices[0] resp := &LLMResponse{ - Content: choice.Message.Content, - FinishReason: choice.FinishReason, + Content: choice.Message.Content, + ReasoningContent: choice.Message.ReasoningContent, + FinishReason: choice.FinishReason, } if payload.Usage.TotalTokens > 0 || payload.Usage.PromptTokens > 0 || payload.Usage.CompletionTokens > 0 { resp.Usage = &UsageInfo{ @@ -170,6 +172,7 @@ func (p *HTTPProvider) buildOpenAICompatChatRequest(messages []Message, tools [] if temperature, ok := float64FromOption(options, "temperature"); ok { requestBody["temperature"] = temperature } + normalizeOpenAICompatThinkingMessages(requestBody) return requestBody } @@ -185,6 +188,9 @@ func openAICompatMessages(messages []Message) []map[string]interface{} { out = append(out, map[string]interface{}{"role": "user", "content": content}) case "assistant": item := map[string]interface{}{"role": "assistant", "content": content} + if reasoning := strings.TrimSpace(msg.ReasoningContent); reasoning != "" { + item["reasoning_content"] = reasoning + } if len(msg.ToolCalls) > 0 { toolCalls := make([]map[string]interface{}, 0, len(msg.ToolCalls)) for _, tc := range msg.ToolCalls { @@ -225,6 +231,96 @@ func openAICompatMessages(messages []Message) []map[string]interface{} { return out } +func normalizeOpenAICompatThinkingMessages(body map[string]interface{}) { + var items []map[string]interface{} + switch raw := body["messages"].(type) { + case []map[string]interface{}: + items = raw + case []interface{}: + items = make([]map[string]interface{}, 0, len(raw)) + for _, item := range raw { + msg, _ := item.(map[string]interface{}) + if msg != nil { + items = append(items, msg) + } + } + } + if len(items) == 0 { + return + } + latestReasoning := "" + hasLatestReasoning := false + for i := range items { + msg := items[i] + if !strings.EqualFold(strings.TrimSpace(fmt.Sprintf("%v", msg["role"])), "assistant") { + continue + } + if raw, ok := msg["reasoning_content"]; ok { + if reasoning := strings.TrimSpace(fmt.Sprintf("%v", raw)); reasoning != "" && reasoning != "" { + latestReasoning = reasoning + hasLatestReasoning = true + } + } + if !assistantMessageHasToolCalls(msg) { + continue + } + existingReasoning := strings.TrimSpace(fmt.Sprintf("%v", msg["reasoning_content"])) + if existingReasoning == "" || existingReasoning == "" { + msg["reasoning_content"] = fallbackAssistantReasoningContent(msg, hasLatestReasoning, latestReasoning) + if reasoning := strings.TrimSpace(fmt.Sprintf("%v", msg["reasoning_content"])); reasoning != "" && reasoning != "" { + latestReasoning = reasoning + hasLatestReasoning = true + } + } + } +} + +func assistantMessageHasToolCalls(msg map[string]interface{}) bool { + switch raw := msg["tool_calls"].(type) { + case []interface{}: + return len(raw) > 0 + case []map[string]interface{}: + return len(raw) > 0 + default: + return false + } +} + +func fallbackAssistantReasoningContent(msg map[string]interface{}, hasLatest bool, latest string) string { + if hasLatest && strings.TrimSpace(latest) != "" { + return latest + } + if text := strings.TrimSpace(fmt.Sprintf("%v", msg["content"])); text != "" && text != "" { + return text + } + switch content := msg["content"].(type) { + case []map[string]interface{}: + return joinAssistantTextParts(content) + case []interface{}: + parts := make([]map[string]interface{}, 0, len(content)) + for _, raw := range content { + part, _ := raw.(map[string]interface{}) + if part != nil { + parts = append(parts, part) + } + } + return joinAssistantTextParts(parts) + default: + return "" + } +} + +func joinAssistantTextParts(parts []map[string]interface{}) string { + texts := make([]string, 0, len(parts)) + for _, part := range parts { + text := strings.TrimSpace(fmt.Sprintf("%v", part["text"])) + if text != "" && text != "" { + texts = append(texts, text) + } + } + return strings.Join(texts, "\n") +} + func openAICompatMessageContent(msg Message) interface{} { if len(msg.ContentParts) == 0 { return msg.Content diff --git a/pkg/providers/openai_compat_provider_test.go b/pkg/providers/openai_compat_provider_test.go index cce6186..4740668 100644 --- a/pkg/providers/openai_compat_provider_test.go +++ b/pkg/providers/openai_compat_provider_test.go @@ -215,3 +215,100 @@ func TestHTTPProviderChatUsesConfiguredChatCompletionsAPI(t *testing.T) { t.Fatalf("usage = %#v, want total_tokens=3", resp.Usage) } } + +func TestParseOpenAICompatResponseCapturesReasoningContent(t *testing.T) { + resp, err := parseOpenAICompatResponse([]byte(`{"choices":[{"message":{"content":"answer","reasoning_content":"hidden chain"},"finish_reason":"stop"}]}`)) + if err != nil { + t.Fatalf("parseOpenAICompatResponse error: %v", err) + } + if resp.ReasoningContent != "hidden chain" { + t.Fatalf("ReasoningContent = %q, want hidden chain", resp.ReasoningContent) + } +} + +func TestOpenAICompatMessagesIncludeReasoningContent(t *testing.T) { + msgs := openAICompatMessages([]Message{{ + Role: "assistant", + Content: "tool plan", + ReasoningContent: "thinking trace", + ToolCalls: []ToolCall{{ + ID: "call_1", + Name: "read_file", + Function: &FunctionCall{ + Name: "read_file", + Arguments: `{"path":"a.txt"}`, + }, + }}, + }}) + if len(msgs) != 1 { + t.Fatalf("messages len = %d", len(msgs)) + } + if got := msgs[0]["reasoning_content"]; got != "thinking trace" { + t.Fatalf("reasoning_content = %#v, want thinking trace", got) + } +} + +func TestNormalizeOpenAICompatThinkingMessagesBackfillsReasoningForToolCalls(t *testing.T) { + body := map[string]interface{}{ + "messages": []map[string]interface{}{ + { + "role": "assistant", + "tool_calls": []map[string]interface{}{ + {"id": "call_1"}, + }, + "content": "thinking content", + }, + }, + } + + normalizeOpenAICompatThinkingMessages(body) + + msgs := body["messages"].([]map[string]interface{}) + if got := msgs[0]["reasoning_content"]; got != "thinking content" { + t.Fatalf("reasoning_content = %#v, want thinking content", got) + } +} + +func TestHTTPProviderChatConfiguredCompatBackfillsReasoningContentForToolHistory(t *testing.T) { + var gotBody map[string]interface{} + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if err := json.NewDecoder(r.Body).Decode(&gotBody); err != nil { + t.Fatalf("decode request: %v", err) + } + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"choices":[{"message":{"content":"ok"},"finish_reason":"stop"}]}`)) + })) + defer server.Close() + + provider := NewHTTPProvider("openai", "token", server.URL+"/v1", "gpt-5", false, "api_key", 5*time.Second, nil) + provider.responsesAPI = "chat_completions" + + _, err := provider.Chat(t.Context(), []Message{ + {Role: "user", Content: "hello"}, + { + Role: "assistant", + Content: "thinking content", + ToolCalls: []ToolCall{{ + ID: "call_1", + Name: "read_file", + Function: &FunctionCall{ + Name: "read_file", + Arguments: `{"path":"a.txt"}`, + }, + }}, + }, + {Role: "tool", ToolCallID: "call_1", Content: "file body"}, + }, nil, "gpt-5(high)", nil) + if err != nil { + t.Fatalf("Chat error: %v", err) + } + + rawMsgs, _ := gotBody["messages"].([]interface{}) + if len(rawMsgs) < 2 { + t.Fatalf("messages = %#v", gotBody["messages"]) + } + assistant, _ := rawMsgs[1].(map[string]interface{}) + if got := assistant["reasoning_content"]; got != "thinking content" { + t.Fatalf("reasoning_content = %#v, want thinking content", got) + } +} diff --git a/pkg/providers/types.go b/pkg/providers/types.go index 3846538..389369a 100644 --- a/pkg/providers/types.go +++ b/pkg/providers/types.go @@ -16,10 +16,11 @@ type FunctionCall struct { } type LLMResponse struct { - Content string `json:"content"` - ToolCalls []ToolCall `json:"tool_calls,omitempty"` - FinishReason string `json:"finish_reason"` - Usage *UsageInfo `json:"usage,omitempty"` + Content string `json:"content"` + ReasoningContent string `json:"reasoning_content,omitempty"` + ToolCalls []ToolCall `json:"tool_calls,omitempty"` + FinishReason string `json:"finish_reason"` + Usage *UsageInfo `json:"usage,omitempty"` } type UsageInfo struct { @@ -29,11 +30,12 @@ type UsageInfo struct { } type Message struct { - Role string `json:"role"` - Content string `json:"content"` - ContentParts []MessageContentPart `json:"content_parts,omitempty"` - ToolCalls []ToolCall `json:"tool_calls,omitempty"` - ToolCallID string `json:"tool_call_id,omitempty"` + Role string `json:"role"` + Content string `json:"content"` + ReasoningContent string `json:"reasoning_content,omitempty"` + ContentParts []MessageContentPart `json:"content_parts,omitempty"` + ToolCalls []ToolCall `json:"tool_calls,omitempty"` + ToolCallID string `json:"tool_call_id,omitempty"` } type MessageContentPart struct { diff --git a/pkg/session/manager.go b/pkg/session/manager.go index 06b2543..7bbee6c 100644 --- a/pkg/session/manager.go +++ b/pkg/session/manager.go @@ -64,9 +64,10 @@ type openClawEvent struct { Type string `json:"type"` Text string `json:"text,omitempty"` } `json:"content,omitempty"` - ToolCallID string `json:"toolCallId,omitempty"` - ToolName string `json:"toolName,omitempty"` - ToolCalls []providers.ToolCall `json:"toolCalls,omitempty"` + ReasoningContent string `json:"reasoning_content,omitempty"` + ToolCallID string `json:"toolCallId,omitempty"` + ToolName string `json:"toolName,omitempty"` + ToolCalls []providers.ToolCall `json:"toolCalls,omitempty"` } `json:"message,omitempty"` } @@ -577,9 +578,10 @@ func toOpenClawMessageEvent(msg providers.Message) openClawEvent { Type string `json:"type"` Text string `json:"text,omitempty"` } `json:"content,omitempty"` - ToolCallID string `json:"toolCallId,omitempty"` - ToolName string `json:"toolName,omitempty"` - ToolCalls []providers.ToolCall `json:"toolCalls,omitempty"` + ReasoningContent string `json:"reasoning_content,omitempty"` + ToolCallID string `json:"toolCallId,omitempty"` + ToolName string `json:"toolName,omitempty"` + ToolCalls []providers.ToolCall `json:"toolCalls,omitempty"` }{ Role: mappedRole, Content: []struct { @@ -588,8 +590,9 @@ func toOpenClawMessageEvent(msg providers.Message) openClawEvent { }{ {Type: "text", Text: msg.Content}, }, - ToolCallID: msg.ToolCallID, - ToolCalls: msg.ToolCalls, + ReasoningContent: msg.ReasoningContent, + ToolCallID: msg.ToolCallID, + ToolCalls: msg.ToolCalls, }, } return e @@ -620,7 +623,7 @@ func fromJSONLLine(line []byte) (providers.Message, bool) { content += part.Text } } - return providers.Message{Role: role, Content: content, ToolCallID: event.Message.ToolCallID, ToolCalls: event.Message.ToolCalls}, true + return providers.Message{Role: role, Content: content, ReasoningContent: event.Message.ReasoningContent, ToolCallID: event.Message.ToolCallID, ToolCalls: event.Message.ToolCalls}, true } func deriveSessionID(key string) string {