diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go index 0f653c7..1f34736 100644 --- a/pkg/agent/loop.go +++ b/pkg/agent/loop.go @@ -9,6 +9,7 @@ package agent import ( "context" "encoding/json" + "errors" "fmt" "hash/fnv" "math" @@ -544,6 +545,11 @@ func (al *AgentLoop) processInbound(ctx context.Context, msg bus.InboundMessage) response, err := al.processPlannedMessage(ctx, msg) if err != nil { + if errors.Is(err, context.Canceled) { + al.audit.Record(al.getTrigger(msg), msg.Channel, msg.SessionKey, true, err) + al.appendTaskAudit(taskID, msg, started, err, true) + return + } response = fmt.Sprintf("Error processing message: %v", err) } diff --git a/pkg/agent/session_planner.go b/pkg/agent/session_planner.go index 62963e6..91d4f29 100644 --- a/pkg/agent/session_planner.go +++ b/pkg/agent/session_planner.go @@ -4,6 +4,7 @@ import ( "bufio" "context" "encoding/json" + "errors" "fmt" "math" "os" @@ -27,6 +28,7 @@ type plannedTaskResult struct { Index int Task plannedTask Output string + Err error ErrText string } @@ -143,19 +145,19 @@ func (al *AgentLoop) runPlannedTasks(ctx context.Context, msg bus.InboundMessage subMsg.Metadata["planned_task_index"] = fmt.Sprintf("%d", t.Index) subMsg.Metadata["planned_task_total"] = fmt.Sprintf("%d", len(tasks)) out, err := al.processMessage(ctx, subMsg) - res := plannedTaskResult{Index: index, Task: t, Output: strings.TrimSpace(out)} + res := plannedTaskResult{Index: index, Task: t, Output: strings.TrimSpace(out), Err: err} if err != nil { res.ErrText = err.Error() } results[index] = res progressMu.Lock() completed++ - if res.ErrText != "" { + if res.ErrText != "" && !isPlannedTaskCancellation(ctx, res) { failed++ } snapshotCompleted := completed snapshotFailed := failed - shouldNotify := shouldPublishPlannedTaskProgress(len(tasks), snapshotCompleted, res, milestones, notified) + shouldNotify := shouldPublishPlannedTaskProgress(ctx, len(tasks), snapshotCompleted, res, milestones, notified) if shouldNotify && res.ErrText == "" { notified[snapshotCompleted] = struct{}{} } @@ -167,6 +169,9 @@ func (al *AgentLoop) runPlannedTasks(ctx context.Context, msg bus.InboundMessage }(i, task) } wg.Wait() + if err := ctx.Err(); err != nil { + return "", err + } var b strings.Builder b.WriteString(fmt.Sprintf("已自动拆解为 %d 个任务并执行:\n\n", len(results))) for _, r := range results { @@ -205,10 +210,13 @@ func plannedProgressMilestones(total int) []int { return out } -func shouldPublishPlannedTaskProgress(total, completed int, res plannedTaskResult, milestones []int, notified map[int]struct{}) bool { +func shouldPublishPlannedTaskProgress(ctx context.Context, total, completed int, res plannedTaskResult, milestones []int, notified map[int]struct{}) bool { if total <= 1 { return false } + if isPlannedTaskCancellation(ctx, res) { + return false + } if strings.TrimSpace(res.ErrText) != "" { return true } @@ -227,6 +235,16 @@ func shouldPublishPlannedTaskProgress(total, completed int, res plannedTaskResul return false } +func isPlannedTaskCancellation(ctx context.Context, res plannedTaskResult) bool { + if res.Err != nil && errors.Is(res.Err, context.Canceled) { + return true + } + if strings.EqualFold(strings.TrimSpace(res.ErrText), context.Canceled.Error()) { + return true + } + return ctx != nil && errors.Is(ctx.Err(), context.Canceled) +} + func (al *AgentLoop) publishPlannedTaskProgress(msg bus.InboundMessage, total, completed, failed int, res plannedTaskResult) { if al == nil || al.bus == nil || total <= 1 { return diff --git a/pkg/agent/session_planner_progress_test.go b/pkg/agent/session_planner_progress_test.go index 3567580..41bd2e1 100644 --- a/pkg/agent/session_planner_progress_test.go +++ b/pkg/agent/session_planner_progress_test.go @@ -1,6 +1,10 @@ package agent -import "testing" +import ( + "context" + "errors" + "testing" +) func TestPlannedProgressMilestones(t *testing.T) { t.Parallel() @@ -16,20 +20,44 @@ func TestShouldPublishPlannedTaskProgress(t *testing.T) { milestones := plannedProgressMilestones(12) notified := map[int]struct{}{} - if shouldPublishPlannedTaskProgress(12, 1, plannedTaskResult{}, milestones, notified) { + if shouldPublishPlannedTaskProgress(context.Background(), 12, 1, plannedTaskResult{}, milestones, notified) { t.Fatalf("did not expect early success notification") } - if !shouldPublishPlannedTaskProgress(12, 4, plannedTaskResult{}, milestones, notified) { + if !shouldPublishPlannedTaskProgress(context.Background(), 12, 4, plannedTaskResult{}, milestones, notified) { t.Fatalf("expected milestone notification") } notified[4] = struct{}{} - if shouldPublishPlannedTaskProgress(12, 4, plannedTaskResult{}, milestones, notified) { + if shouldPublishPlannedTaskProgress(context.Background(), 12, 4, plannedTaskResult{}, milestones, notified) { t.Fatalf("did not expect duplicate milestone notification") } - if !shouldPublishPlannedTaskProgress(12, 5, plannedTaskResult{ErrText: "boom"}, milestones, notified) { + if !shouldPublishPlannedTaskProgress(context.Background(), 12, 5, plannedTaskResult{ErrText: "boom"}, milestones, notified) { t.Fatalf("expected failure notification") } - if shouldPublishPlannedTaskProgress(3, 3, plannedTaskResult{}, plannedProgressMilestones(3), map[int]struct{}{}) { + if shouldPublishPlannedTaskProgress(context.Background(), 3, 3, plannedTaskResult{}, plannedProgressMilestones(3), map[int]struct{}{}) { t.Fatalf("did not expect final success notification") } + if shouldPublishPlannedTaskProgress(context.Background(), 12, 5, plannedTaskResult{Err: context.Canceled, ErrText: context.Canceled.Error()}, milestones, notified) { + t.Fatalf("did not expect cancellation notification") + } + ctx, cancel := context.WithCancel(context.Background()) + cancel() + if shouldPublishPlannedTaskProgress(ctx, 12, 5, plannedTaskResult{Err: errors.New("worker exited after parent stop"), ErrText: "worker exited after parent stop"}, milestones, notified) { + t.Fatalf("did not expect notification after parent cancellation") + } +} + +func TestIsPlannedTaskCancellation(t *testing.T) { + t.Parallel() + + if !isPlannedTaskCancellation(context.Background(), plannedTaskResult{Err: context.Canceled, ErrText: context.Canceled.Error()}) { + t.Fatalf("expected direct context cancellation to be detected") + } + ctx, cancel := context.WithCancel(context.Background()) + cancel() + if !isPlannedTaskCancellation(ctx, plannedTaskResult{Err: errors.New("worker exited after parent stop"), ErrText: "worker exited after parent stop"}) { + t.Fatalf("expected canceled parent context to suppress planned task result") + } + if isPlannedTaskCancellation(context.Background(), plannedTaskResult{Err: errors.New("boom"), ErrText: "boom"}) { + t.Fatalf("did not expect non-cancellation error to be suppressed") + } }