mirror of
https://github.com/YspCoder/clawgo.git
synced 2026-05-05 21:47:30 +08:00
fix: enforce strict tool-call pairing across chat paths
This commit is contained in:
@@ -442,7 +442,10 @@ func (al *AgentLoop) GetSessionHistory(sessionKey string) []providers.Message {
|
||||
}
|
||||
|
||||
func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage) (string, error) {
|
||||
msg.SessionKey = "main"
|
||||
msg.SessionKey = strings.TrimSpace(msg.SessionKey)
|
||||
if msg.SessionKey == "" {
|
||||
msg.SessionKey = "main"
|
||||
}
|
||||
unlock := al.lockSessionRun(msg.SessionKey)
|
||||
defer unlock()
|
||||
// Add message preview to log
|
||||
@@ -908,20 +911,6 @@ func (al *AgentLoop) processSystemMessage(ctx context.Context, msg bus.InboundMe
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
errMsg := strings.ToLower(err.Error())
|
||||
if strings.Contains(errMsg, "no tool call found for function call output") {
|
||||
logger.WarnCF("agent", "System message hit orphan tool-call chain, retry with fresh context", map[string]interface{}{"iteration": iteration, "session": sessionKey})
|
||||
messages = al.contextBuilder.BuildMessages(
|
||||
nil,
|
||||
"",
|
||||
msg.Content,
|
||||
nil,
|
||||
originChannel,
|
||||
originChatID,
|
||||
responseLang,
|
||||
)
|
||||
continue
|
||||
}
|
||||
logger.ErrorCF("agent", "LLM call failed in system message",
|
||||
map[string]interface{}{
|
||||
"iteration": iteration,
|
||||
|
||||
118
pkg/agent/loop_session_regression_test.go
Normal file
118
pkg/agent/loop_session_regression_test.go
Normal file
@@ -0,0 +1,118 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"clawgo/pkg/bus"
|
||||
"clawgo/pkg/config"
|
||||
"clawgo/pkg/providers"
|
||||
)
|
||||
|
||||
type recordingProvider struct {
|
||||
mu sync.Mutex
|
||||
calls [][]providers.Message
|
||||
responses []providers.LLMResponse
|
||||
}
|
||||
|
||||
func (p *recordingProvider) Chat(_ context.Context, messages []providers.Message, _ []providers.ToolDefinition, _ string, _ map[string]interface{}) (*providers.LLMResponse, error) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
cp := make([]providers.Message, len(messages))
|
||||
copy(cp, messages)
|
||||
p.calls = append(p.calls, cp)
|
||||
if len(p.responses) == 0 {
|
||||
resp := providers.LLMResponse{Content: "ok", FinishReason: "stop"}
|
||||
return &resp, nil
|
||||
}
|
||||
resp := p.responses[0]
|
||||
p.responses = p.responses[1:]
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
func (p *recordingProvider) GetDefaultModel() string { return "test-model" }
|
||||
|
||||
func setupLoop(t *testing.T, rp *recordingProvider) *AgentLoop {
|
||||
t.Helper()
|
||||
cfg := config.DefaultConfig()
|
||||
cfg.Agents.Defaults.Workspace = filepath.Join(t.TempDir(), "workspace")
|
||||
cfg.Agents.Defaults.MaxToolIterations = 2
|
||||
cfg.Agents.Defaults.ContextCompaction.Enabled = false
|
||||
return NewAgentLoop(cfg, bus.NewMessageBus(), rp, nil)
|
||||
}
|
||||
|
||||
func lastUserContent(msgs []providers.Message) string {
|
||||
for i := len(msgs) - 1; i >= 0; i-- {
|
||||
if msgs[i].Role == "user" {
|
||||
return msgs[i].Content
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func containsUserContent(msgs []providers.Message, needle string) bool {
|
||||
for _, m := range msgs {
|
||||
if m.Role == "user" && m.Content == needle {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func TestProcessDirect_UsesCallerSessionKey(t *testing.T) {
|
||||
rp := &recordingProvider{}
|
||||
loop := setupLoop(t, rp)
|
||||
|
||||
if _, err := loop.ProcessDirect(context.Background(), "from-session-a", "session-a"); err != nil {
|
||||
t.Fatalf("ProcessDirect session-a failed: %v", err)
|
||||
}
|
||||
if _, err := loop.ProcessDirect(context.Background(), "from-session-b", "session-b"); err != nil {
|
||||
t.Fatalf("ProcessDirect session-b failed: %v", err)
|
||||
}
|
||||
|
||||
if len(rp.calls) != 2 {
|
||||
t.Fatalf("expected 2 provider calls, got %d", len(rp.calls))
|
||||
}
|
||||
second := rp.calls[1]
|
||||
if got := lastUserContent(second); got != "from-session-b" {
|
||||
t.Fatalf("unexpected last user content in second call: %q", got)
|
||||
}
|
||||
if containsUserContent(second, "from-session-a") {
|
||||
t.Fatalf("session-a message leaked into session-b history")
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessSystemMessage_UsesOriginSessionKey(t *testing.T) {
|
||||
rp := &recordingProvider{}
|
||||
loop := setupLoop(t, rp)
|
||||
|
||||
sys := bus.InboundMessage{Channel: "system", SenderID: "cron", ChatID: "telegram:chat-1", Content: "system task"}
|
||||
if _, err := loop.processMessage(context.Background(), sys); err != nil {
|
||||
t.Fatalf("processMessage(system) failed: %v", err)
|
||||
}
|
||||
if _, err := loop.ProcessDirect(context.Background(), "follow-up", "telegram:chat-1"); err != nil {
|
||||
t.Fatalf("ProcessDirect follow-up failed: %v", err)
|
||||
}
|
||||
|
||||
if len(rp.calls) != 2 {
|
||||
t.Fatalf("expected 2 provider calls, got %d", len(rp.calls))
|
||||
}
|
||||
second := rp.calls[1]
|
||||
want := "[System: cron] " + rewriteSystemMessageContent("system task", loop.systemRewriteTemplate)
|
||||
if !containsUserContent(second, want) {
|
||||
t.Fatalf("expected system marker in follow-up history, want=%q got=%v", want, summarizeUsers(second))
|
||||
}
|
||||
}
|
||||
|
||||
func summarizeUsers(msgs []providers.Message) []string {
|
||||
out := []string{}
|
||||
for _, m := range msgs {
|
||||
if m.Role == "user" {
|
||||
out = append(out, fmt.Sprintf("%q", m.Content))
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
@@ -66,19 +66,6 @@ func (p *HTTPProvider) Chat(ctx context.Context, messages []Message, tools []Too
|
||||
}
|
||||
if statusCode != http.StatusOK {
|
||||
preview := previewResponseBody(body)
|
||||
if statusCode == http.StatusBadRequest && strings.Contains(strings.ToLower(preview), "no tool call found for function call output") {
|
||||
// Retry once with sanitized history to avoid orphaned tool outputs causing hard-fail.
|
||||
safeMessages := sanitizeResponsesRetryMessages(messages)
|
||||
body2, status2, ctype2, err2 := p.callResponses(ctx, safeMessages, tools, model, options)
|
||||
if err2 == nil && status2 == http.StatusOK && json.Valid(body2) {
|
||||
logger.InfoCF("provider", "Recovered responses 400 by sanitizing tool outputs", map[string]interface{}{"messages_before": len(messages), "messages_after": len(safeMessages)})
|
||||
return parseResponsesAPIResponse(body2)
|
||||
}
|
||||
if err2 != nil {
|
||||
return nil, err2
|
||||
}
|
||||
return nil, fmt.Errorf("API error (status %d, content-type %q): %s", status2, ctype2, previewResponseBody(body2))
|
||||
}
|
||||
return nil, fmt.Errorf("API error (status %d, content-type %q): %s", statusCode, contentType, preview)
|
||||
}
|
||||
if !json.Valid(body) {
|
||||
@@ -221,22 +208,6 @@ func toChatCompletionsContent(msg Message) []map[string]interface{} {
|
||||
return content
|
||||
}
|
||||
|
||||
func sanitizeResponsesRetryMessages(messages []Message) []Message {
|
||||
out := make([]Message, 0, len(messages))
|
||||
for _, m := range messages {
|
||||
if strings.EqualFold(strings.TrimSpace(m.Role), "tool") {
|
||||
text := strings.TrimSpace(m.Content)
|
||||
if text == "" {
|
||||
continue
|
||||
}
|
||||
out = append(out, Message{Role: "user", Content: "[tool_result_fallback] " + text})
|
||||
continue
|
||||
}
|
||||
out = append(out, m)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func toResponsesInputItems(msg Message) []map[string]interface{} {
|
||||
return toResponsesInputItemsWithState(msg, nil)
|
||||
}
|
||||
@@ -299,12 +270,12 @@ func toResponsesInputItemsWithState(msg Message, pendingCalls map[string]struct{
|
||||
case "tool":
|
||||
callID := strings.TrimSpace(msg.ToolCallID)
|
||||
if callID == "" {
|
||||
return []map[string]interface{}{responsesMessageItem("user", msg.Content, "input_text")}
|
||||
return nil
|
||||
}
|
||||
if pendingCalls != nil {
|
||||
if _, ok := pendingCalls[callID]; !ok {
|
||||
// Avoid invalid orphan/duplicate tool outputs in /responses payload.
|
||||
return []map[string]interface{}{responsesMessageItem("user", msg.Content, "input_text")}
|
||||
// Strict pairing: drop orphan/duplicate tool outputs instead of degrading role.
|
||||
return nil
|
||||
}
|
||||
delete(pendingCalls, callID)
|
||||
}
|
||||
|
||||
42
pkg/providers/http_provider_toolcall_pairing_test.go
Normal file
42
pkg/providers/http_provider_toolcall_pairing_test.go
Normal file
@@ -0,0 +1,42 @@
|
||||
package providers
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestToResponsesInputItemsWithState_DropsOrphanToolOutputs(t *testing.T) {
|
||||
pending := map[string]struct{}{}
|
||||
|
||||
orphan := Message{Role: "tool", ToolCallID: "call-orphan", Content: "orphan output"}
|
||||
if got := toResponsesInputItemsWithState(orphan, pending); len(got) != 0 {
|
||||
t.Fatalf("expected orphan tool output to be dropped, got: %#v", got)
|
||||
}
|
||||
|
||||
assistant := Message{
|
||||
Role: "assistant",
|
||||
ToolCalls: []ToolCall{{
|
||||
ID: "call-1",
|
||||
Name: "read",
|
||||
Arguments: map[string]interface{}{
|
||||
"path": "README.md",
|
||||
},
|
||||
}},
|
||||
}
|
||||
items := toResponsesInputItemsWithState(assistant, pending)
|
||||
if len(items) == 0 {
|
||||
t.Fatalf("assistant tool call should produce responses items")
|
||||
}
|
||||
if _, ok := pending["call-1"]; !ok {
|
||||
t.Fatalf("assistant tool call id should be tracked as pending")
|
||||
}
|
||||
|
||||
matched := Message{Role: "tool", ToolCallID: "call-1", Content: "file content"}
|
||||
matchedItems := toResponsesInputItemsWithState(matched, pending)
|
||||
if len(matchedItems) != 1 {
|
||||
t.Fatalf("expected matched tool output item, got %#v", matchedItems)
|
||||
}
|
||||
if matchedItems[0]["type"] != "function_call_output" {
|
||||
t.Fatalf("expected function_call_output item, got %#v", matchedItems[0])
|
||||
}
|
||||
if _, ok := pending["call-1"]; ok {
|
||||
t.Fatalf("matched tool output should clear pending call id")
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user