diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go index bd3e72e..42d7207 100644 --- a/pkg/agent/loop.go +++ b/pkg/agent/loop.go @@ -442,7 +442,10 @@ func (al *AgentLoop) GetSessionHistory(sessionKey string) []providers.Message { } func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage) (string, error) { - msg.SessionKey = "main" + msg.SessionKey = strings.TrimSpace(msg.SessionKey) + if msg.SessionKey == "" { + msg.SessionKey = "main" + } unlock := al.lockSessionRun(msg.SessionKey) defer unlock() // Add message preview to log @@ -908,20 +911,6 @@ func (al *AgentLoop) processSystemMessage(ctx context.Context, msg bus.InboundMe }) if err != nil { - errMsg := strings.ToLower(err.Error()) - if strings.Contains(errMsg, "no tool call found for function call output") { - logger.WarnCF("agent", "System message hit orphan tool-call chain, retry with fresh context", map[string]interface{}{"iteration": iteration, "session": sessionKey}) - messages = al.contextBuilder.BuildMessages( - nil, - "", - msg.Content, - nil, - originChannel, - originChatID, - responseLang, - ) - continue - } logger.ErrorCF("agent", "LLM call failed in system message", map[string]interface{}{ "iteration": iteration, diff --git a/pkg/agent/loop_session_regression_test.go b/pkg/agent/loop_session_regression_test.go new file mode 100644 index 0000000..7385199 --- /dev/null +++ b/pkg/agent/loop_session_regression_test.go @@ -0,0 +1,118 @@ +package agent + +import ( + "context" + "fmt" + "path/filepath" + "sync" + "testing" + + "clawgo/pkg/bus" + "clawgo/pkg/config" + "clawgo/pkg/providers" +) + +type recordingProvider struct { + mu sync.Mutex + calls [][]providers.Message + responses []providers.LLMResponse +} + +func (p *recordingProvider) Chat(_ context.Context, messages []providers.Message, _ []providers.ToolDefinition, _ string, _ map[string]interface{}) (*providers.LLMResponse, error) { + p.mu.Lock() + defer p.mu.Unlock() + cp := make([]providers.Message, len(messages)) + copy(cp, messages) + p.calls = append(p.calls, cp) + if len(p.responses) == 0 { + resp := providers.LLMResponse{Content: "ok", FinishReason: "stop"} + return &resp, nil + } + resp := p.responses[0] + p.responses = p.responses[1:] + return &resp, nil +} + +func (p *recordingProvider) GetDefaultModel() string { return "test-model" } + +func setupLoop(t *testing.T, rp *recordingProvider) *AgentLoop { + t.Helper() + cfg := config.DefaultConfig() + cfg.Agents.Defaults.Workspace = filepath.Join(t.TempDir(), "workspace") + cfg.Agents.Defaults.MaxToolIterations = 2 + cfg.Agents.Defaults.ContextCompaction.Enabled = false + return NewAgentLoop(cfg, bus.NewMessageBus(), rp, nil) +} + +func lastUserContent(msgs []providers.Message) string { + for i := len(msgs) - 1; i >= 0; i-- { + if msgs[i].Role == "user" { + return msgs[i].Content + } + } + return "" +} + +func containsUserContent(msgs []providers.Message, needle string) bool { + for _, m := range msgs { + if m.Role == "user" && m.Content == needle { + return true + } + } + return false +} + +func TestProcessDirect_UsesCallerSessionKey(t *testing.T) { + rp := &recordingProvider{} + loop := setupLoop(t, rp) + + if _, err := loop.ProcessDirect(context.Background(), "from-session-a", "session-a"); err != nil { + t.Fatalf("ProcessDirect session-a failed: %v", err) + } + if _, err := loop.ProcessDirect(context.Background(), "from-session-b", "session-b"); err != nil { + t.Fatalf("ProcessDirect session-b failed: %v", err) + } + + if len(rp.calls) != 2 { + t.Fatalf("expected 2 provider calls, got %d", len(rp.calls)) + } + second := rp.calls[1] + if got := lastUserContent(second); got != "from-session-b" { + t.Fatalf("unexpected last user content in second call: %q", got) + } + if containsUserContent(second, "from-session-a") { + t.Fatalf("session-a message leaked into session-b history") + } +} + +func TestProcessSystemMessage_UsesOriginSessionKey(t *testing.T) { + rp := &recordingProvider{} + loop := setupLoop(t, rp) + + sys := bus.InboundMessage{Channel: "system", SenderID: "cron", ChatID: "telegram:chat-1", Content: "system task"} + if _, err := loop.processMessage(context.Background(), sys); err != nil { + t.Fatalf("processMessage(system) failed: %v", err) + } + if _, err := loop.ProcessDirect(context.Background(), "follow-up", "telegram:chat-1"); err != nil { + t.Fatalf("ProcessDirect follow-up failed: %v", err) + } + + if len(rp.calls) != 2 { + t.Fatalf("expected 2 provider calls, got %d", len(rp.calls)) + } + second := rp.calls[1] + want := "[System: cron] " + rewriteSystemMessageContent("system task", loop.systemRewriteTemplate) + if !containsUserContent(second, want) { + t.Fatalf("expected system marker in follow-up history, want=%q got=%v", want, summarizeUsers(second)) + } +} + +func summarizeUsers(msgs []providers.Message) []string { + out := []string{} + for _, m := range msgs { + if m.Role == "user" { + out = append(out, fmt.Sprintf("%q", m.Content)) + } + } + return out +} diff --git a/pkg/providers/http_provider.go b/pkg/providers/http_provider.go index 0342c71..697e4d8 100644 --- a/pkg/providers/http_provider.go +++ b/pkg/providers/http_provider.go @@ -66,19 +66,6 @@ func (p *HTTPProvider) Chat(ctx context.Context, messages []Message, tools []Too } if statusCode != http.StatusOK { preview := previewResponseBody(body) - if statusCode == http.StatusBadRequest && strings.Contains(strings.ToLower(preview), "no tool call found for function call output") { - // Retry once with sanitized history to avoid orphaned tool outputs causing hard-fail. - safeMessages := sanitizeResponsesRetryMessages(messages) - body2, status2, ctype2, err2 := p.callResponses(ctx, safeMessages, tools, model, options) - if err2 == nil && status2 == http.StatusOK && json.Valid(body2) { - logger.InfoCF("provider", "Recovered responses 400 by sanitizing tool outputs", map[string]interface{}{"messages_before": len(messages), "messages_after": len(safeMessages)}) - return parseResponsesAPIResponse(body2) - } - if err2 != nil { - return nil, err2 - } - return nil, fmt.Errorf("API error (status %d, content-type %q): %s", status2, ctype2, previewResponseBody(body2)) - } return nil, fmt.Errorf("API error (status %d, content-type %q): %s", statusCode, contentType, preview) } if !json.Valid(body) { @@ -221,22 +208,6 @@ func toChatCompletionsContent(msg Message) []map[string]interface{} { return content } -func sanitizeResponsesRetryMessages(messages []Message) []Message { - out := make([]Message, 0, len(messages)) - for _, m := range messages { - if strings.EqualFold(strings.TrimSpace(m.Role), "tool") { - text := strings.TrimSpace(m.Content) - if text == "" { - continue - } - out = append(out, Message{Role: "user", Content: "[tool_result_fallback] " + text}) - continue - } - out = append(out, m) - } - return out -} - func toResponsesInputItems(msg Message) []map[string]interface{} { return toResponsesInputItemsWithState(msg, nil) } @@ -299,12 +270,12 @@ func toResponsesInputItemsWithState(msg Message, pendingCalls map[string]struct{ case "tool": callID := strings.TrimSpace(msg.ToolCallID) if callID == "" { - return []map[string]interface{}{responsesMessageItem("user", msg.Content, "input_text")} + return nil } if pendingCalls != nil { if _, ok := pendingCalls[callID]; !ok { - // Avoid invalid orphan/duplicate tool outputs in /responses payload. - return []map[string]interface{}{responsesMessageItem("user", msg.Content, "input_text")} + // Strict pairing: drop orphan/duplicate tool outputs instead of degrading role. + return nil } delete(pendingCalls, callID) } diff --git a/pkg/providers/http_provider_toolcall_pairing_test.go b/pkg/providers/http_provider_toolcall_pairing_test.go new file mode 100644 index 0000000..299e6c7 --- /dev/null +++ b/pkg/providers/http_provider_toolcall_pairing_test.go @@ -0,0 +1,42 @@ +package providers + +import "testing" + +func TestToResponsesInputItemsWithState_DropsOrphanToolOutputs(t *testing.T) { + pending := map[string]struct{}{} + + orphan := Message{Role: "tool", ToolCallID: "call-orphan", Content: "orphan output"} + if got := toResponsesInputItemsWithState(orphan, pending); len(got) != 0 { + t.Fatalf("expected orphan tool output to be dropped, got: %#v", got) + } + + assistant := Message{ + Role: "assistant", + ToolCalls: []ToolCall{{ + ID: "call-1", + Name: "read", + Arguments: map[string]interface{}{ + "path": "README.md", + }, + }}, + } + items := toResponsesInputItemsWithState(assistant, pending) + if len(items) == 0 { + t.Fatalf("assistant tool call should produce responses items") + } + if _, ok := pending["call-1"]; !ok { + t.Fatalf("assistant tool call id should be tracked as pending") + } + + matched := Message{Role: "tool", ToolCallID: "call-1", Content: "file content"} + matchedItems := toResponsesInputItemsWithState(matched, pending) + if len(matchedItems) != 1 { + t.Fatalf("expected matched tool output item, got %#v", matchedItems) + } + if matchedItems[0]["type"] != "function_call_output" { + t.Fatalf("expected function_call_output item, got %#v", matchedItems[0]) + } + if _, ok := pending["call-1"]; ok { + t.Fatalf("matched tool output should clear pending call id") + } +}