From d1abd73e63e5df82b914cb5494f3a240bbf6d2de Mon Sep 17 00:00:00 2001 From: lpf Date: Mon, 9 Mar 2026 12:33:00 +0800 Subject: [PATCH] Optimize agent planning and subagent runtime --- pkg/agent/loop.go | 294 +++++++++++++++++++++++- pkg/agent/loop_system_notify_test.go | 91 ++++++++ pkg/agent/router_dispatch.go | 9 +- pkg/agent/runtime_admin.go | 29 +-- pkg/agent/session_planner.go | 133 ++++++++++- pkg/agent/session_planner_split_test.go | 80 ++++++- pkg/api/server.go | 238 ++++++++++++++++--- pkg/tools/subagent.go | 95 ++++++++ pkg/tools/subagent_router.go | 95 ++++---- pkg/tools/subagent_router_test.go | 26 +++ pkg/tools/task_watchdog.go | 2 + 11 files changed, 984 insertions(+), 108 deletions(-) diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go index 7f4154e..3cec7e2 100644 --- a/pkg/agent/loop.go +++ b/pkg/agent/loop.go @@ -67,6 +67,24 @@ type AgentLoop struct { subagentConfigTool *tools.SubagentConfigTool nodeRouter *nodes.Router configPath string + subagentDigestMu sync.Mutex + subagentDigestDelay time.Duration + subagentDigests map[string]*subagentDigestState +} + +type subagentDigestItem struct { + agentID string + reason string + status string + taskSummary string + resultSummary string +} + +type subagentDigestState struct { + channel string + chatID string + items map[string]subagentDigestItem + dueAt time.Time } func (al *AgentLoop) SetConfigPath(path string) { @@ -320,7 +338,10 @@ func NewAgentLoop(cfg *config.Config, msgBus *bus.MessageBus, provider providers subagentRouter: subagentRouter, subagentConfigTool: subagentConfigTool, nodeRouter: nodesRouter, + subagentDigestDelay: 5 * time.Second, + subagentDigests: map[string]*subagentDigestState{}, } + go loop.runSubagentDigestTicker() // Initialize provider fallback chain (primary + proxy_fallbacks). loop.providerPool = map[string]providers.LLMProvider{} loop.providerNames = []string{} @@ -371,7 +392,31 @@ func NewAgentLoop(cfg *config.Config, msgBus *bus.MessageBus, provider providers sessionKey = fmt.Sprintf("subagent:%s", strings.TrimSpace(task.ID)) } taskInput := loop.buildSubagentTaskInput(task) - return loop.ProcessDirectWithOptions(ctx, taskInput, sessionKey, task.OriginChannel, task.OriginChatID, task.MemoryNS, task.ToolAllowlist) + ns := normalizeMemoryNamespace(task.MemoryNS) + ctx = withMemoryNamespaceContext(ctx, ns) + ctx = withToolAllowlistContext(ctx, task.ToolAllowlist) + channel := strings.TrimSpace(task.OriginChannel) + if channel == "" { + channel = "cli" + } + chatID := strings.TrimSpace(task.OriginChatID) + if chatID == "" { + chatID = "direct" + } + msg := bus.InboundMessage{ + Channel: channel, + SenderID: "subagent", + ChatID: chatID, + Content: taskInput, + SessionKey: sessionKey, + Metadata: map[string]string{ + "memory_namespace": ns, + "memory_ns": ns, + "disable_planning": "true", + "trigger": "subagent", + }, + } + return loop.processMessage(ctx, msg) }) return loop @@ -1349,6 +1394,10 @@ func (al *AgentLoop) processSystemMessage(ctx context.Context, msg bus.InboundMe "chat_id": msg.ChatID, }) + if al.handleSubagentSystemMessage(msg) { + return "", nil + } + originChannel, originChatID := resolveSystemOrigin(msg.ChatID) // Use the origin session for context @@ -1491,6 +1540,36 @@ func (al *AgentLoop) processSystemMessage(ctx context.Context, msg bus.InboundMe return finalContent, nil } +func (al *AgentLoop) handleSubagentSystemMessage(msg bus.InboundMessage) bool { + if !isSubagentSystemMessage(msg) { + return false + } + reason := "" + agentID := "" + status := "" + taskSummary := "" + resultSummary := "" + if msg.Metadata != nil { + reason = strings.ToLower(strings.TrimSpace(msg.Metadata["notify_reason"])) + agentID = strings.TrimSpace(msg.Metadata["agent_id"]) + status = strings.TrimSpace(msg.Metadata["status"]) + } + if agentID == "" { + agentID = strings.TrimSpace(strings.TrimPrefix(msg.SenderID, "subagent:")) + } + if taskSummary == "" || resultSummary == "" { + taskSummary, resultSummary = parseSubagentSystemContent(msg.Content) + } + al.enqueueSubagentDigest(msg, subagentDigestItem{ + agentID: agentID, + reason: reason, + status: status, + taskSummary: taskSummary, + resultSummary: resultSummary, + }) + return true +} + // truncate returns a truncated version of s with at most maxLen characters. // If the string is truncated, "..." is appended to indicate truncation. // If the string fits within maxLen, it is returned unchanged. @@ -2219,6 +2298,219 @@ func fallbackSubagentNotification(msg bus.InboundMessage) (string, bool) { return content, true } +func parseSubagentSystemContent(content string) (string, string) { + content = strings.TrimSpace(content) + if content == "" { + return "", "" + } + lines := strings.Split(content, "\n") + taskSummary := "" + resultSummary := "" + for _, line := range lines { + t := strings.TrimSpace(line) + if t == "" { + continue + } + lower := strings.ToLower(t) + switch { + case strings.HasPrefix(lower, "task:"): + taskSummary = strings.TrimSpace(t[len("task:"):]) + case strings.HasPrefix(lower, "summary:"): + resultSummary = strings.TrimSpace(t[len("summary:"):]) + } + } + if taskSummary == "" && len(lines) > 0 { + taskSummary = summarizeSystemNotificationText(content, 120) + } + return taskSummary, resultSummary +} + +func summarizeSystemNotificationText(s string, max int) string { + s = strings.TrimSpace(strings.ReplaceAll(s, "\r\n", "\n")) + s = strings.ReplaceAll(s, "\n", " ") + s = strings.Join(strings.Fields(s), " ") + if s == "" { + return "" + } + if max > 0 && len(s) > max { + return strings.TrimSpace(s[:max-3]) + "..." + } + return s +} + +func (al *AgentLoop) enqueueSubagentDigest(msg bus.InboundMessage, item subagentDigestItem) { + if al == nil || al.bus == nil { + return + } + originChannel, originChatID := resolveSystemOrigin(msg.ChatID) + key := originChannel + "\x00" + originChatID + delay := al.subagentDigestDelay + if delay <= 0 { + delay = 5 * time.Second + } + + al.subagentDigestMu.Lock() + state, ok := al.subagentDigests[key] + if !ok || state == nil { + state = &subagentDigestState{ + channel: originChannel, + chatID: originChatID, + items: map[string]subagentDigestItem{}, + } + al.subagentDigests[key] = state + } + itemKey := subagentDigestItemKey(item) + state.items[itemKey] = item + state.dueAt = time.Now().Add(delay) + al.subagentDigestMu.Unlock() +} + +func subagentDigestItemKey(item subagentDigestItem) string { + agentID := strings.ToLower(strings.TrimSpace(item.agentID)) + reason := strings.ToLower(strings.TrimSpace(item.reason)) + task := strings.ToLower(strings.TrimSpace(item.taskSummary)) + if agentID == "" { + agentID = "subagent" + } + return agentID + "\x00" + reason + "\x00" + task +} + +func (al *AgentLoop) flushSubagentDigest(key string) { + if al == nil || al.bus == nil { + return + } + al.subagentDigestMu.Lock() + state := al.subagentDigests[key] + delete(al.subagentDigests, key) + al.subagentDigestMu.Unlock() + if state == nil || len(state.items) == 0 { + return + } + content := formatSubagentDigestSummary(state.items) + if strings.TrimSpace(content) == "" { + return + } + al.bus.PublishOutbound(bus.OutboundMessage{ + Channel: state.channel, + ChatID: state.chatID, + Content: content, + }) +} + +func (al *AgentLoop) runSubagentDigestTicker() { + if al == nil { + return + } + tick := tools.GlobalWatchdogTick + if tick <= 0 { + tick = time.Second + } + ticker := time.NewTicker(tick) + defer ticker.Stop() + for now := range ticker.C { + al.flushDueSubagentDigests(now) + } +} + +func (al *AgentLoop) flushDueSubagentDigests(now time.Time) { + if al == nil || al.bus == nil { + return + } + dueKeys := make([]string, 0, 4) + al.subagentDigestMu.Lock() + for key, state := range al.subagentDigests { + if state == nil || state.dueAt.IsZero() || now.Before(state.dueAt) { + continue + } + dueKeys = append(dueKeys, key) + } + al.subagentDigestMu.Unlock() + for _, key := range dueKeys { + al.flushSubagentDigest(key) + } +} + +func formatSubagentDigestSummary(items map[string]subagentDigestItem) string { + if len(items) == 0 { + return "" + } + list := make([]subagentDigestItem, 0, len(items)) + completed := 0 + blocked := 0 + milestone := 0 + failed := 0 + for _, item := range items { + list = append(list, item) + switch { + case strings.EqualFold(strings.TrimSpace(item.reason), "blocked"): + blocked++ + case strings.EqualFold(strings.TrimSpace(item.reason), "milestone"): + milestone++ + case strings.EqualFold(strings.TrimSpace(item.status), "failed"): + failed++ + default: + completed++ + } + } + sort.Slice(list, func(i, j int) bool { + left := strings.TrimSpace(list[i].agentID) + right := strings.TrimSpace(list[j].agentID) + if left == right { + return strings.TrimSpace(list[i].taskSummary) < strings.TrimSpace(list[j].taskSummary) + } + return left < right + }) + var sb strings.Builder + sb.WriteString("阶段总结") + stats := make([]string, 0, 4) + if completed > 0 { + stats = append(stats, fmt.Sprintf("完成 %d", completed)) + } + if blocked > 0 { + stats = append(stats, fmt.Sprintf("受阻 %d", blocked)) + } + if failed > 0 { + stats = append(stats, fmt.Sprintf("失败 %d", failed)) + } + if milestone > 0 { + stats = append(stats, fmt.Sprintf("进展 %d", milestone)) + } + if len(stats) > 0 { + sb.WriteString("(" + strings.Join(stats, ",") + ")") + } + sb.WriteString("\n") + for _, item := range list { + agentLabel := strings.TrimSpace(item.agentID) + if agentLabel == "" { + agentLabel = "subagent" + } + statusText := "已完成" + switch { + case strings.EqualFold(strings.TrimSpace(item.reason), "blocked"): + statusText = "受阻" + case strings.EqualFold(strings.TrimSpace(item.reason), "milestone"): + statusText = "有进展" + case strings.EqualFold(strings.TrimSpace(item.status), "failed"): + statusText = "失败" + } + sb.WriteString("- " + agentLabel + ":" + statusText) + if task := strings.TrimSpace(item.taskSummary); task != "" { + sb.WriteString(",任务:" + task) + } + if summary := strings.TrimSpace(item.resultSummary); summary != "" { + label := "摘要" + if statusText == "受阻" { + label = "原因" + } else if statusText == "有进展" { + label = "进度" + } + sb.WriteString("," + label + ":" + summary) + } + sb.WriteString("\n") + } + return strings.TrimSpace(sb.String()) +} + func shouldFlushTelegramStreamSnapshot(s string) bool { s = strings.TrimRight(s, " \t") if s == "" { diff --git a/pkg/agent/loop_system_notify_test.go b/pkg/agent/loop_system_notify_test.go index d052a1d..ec845f9 100644 --- a/pkg/agent/loop_system_notify_test.go +++ b/pkg/agent/loop_system_notify_test.go @@ -1,8 +1,10 @@ package agent import ( + "context" "strings" "testing" + "time" "clawgo/pkg/bus" ) @@ -66,3 +68,92 @@ func TestPrepareOutboundSubagentNoReplyFallbackWithMissingOrigin(t *testing.T) { t.Fatalf("unexpected fallback content: %q", outbound.Content) } } + +func TestProcessSystemMessageSubagentBlockedQueuedIntoDigest(t *testing.T) { + msgBus := bus.NewMessageBus() + al := &AgentLoop{ + bus: msgBus, + subagentDigestDelay: 10 * time.Millisecond, + subagentDigests: map[string]*subagentDigestState{}, + } + out, err := al.processSystemMessage(context.Background(), bus.InboundMessage{ + Channel: "system", + SenderID: "subagent:subagent-3", + ChatID: "telegram:9527", + Content: "Subagent update\nagent: coder\nrun: subagent-3\nstatus: blocked\nreason: blocked\ntask: 修复登录\nsummary: rate limit", + Metadata: map[string]string{ + "trigger": "subagent", + "agent_id": "coder", + "status": "failed", + "notify_reason": "blocked", + }, + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if out != "" { + t.Fatalf("expected queued digest with no immediate output, got %q", out) + } + al.flushDueSubagentDigests(time.Now().Add(20 * time.Millisecond)) + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + outbound, ok := msgBus.SubscribeOutbound(ctx) + if !ok { + t.Fatalf("expected outbound digest") + } + if !strings.Contains(outbound.Content, "阶段总结") || !strings.Contains(outbound.Content, "coder") || !strings.Contains(outbound.Content, "受阻") { + t.Fatalf("unexpected digest content: %q", outbound.Content) + } +} + +func TestProcessSystemMessageSubagentDigestMergesMultipleUpdates(t *testing.T) { + msgBus := bus.NewMessageBus() + al := &AgentLoop{ + bus: msgBus, + subagentDigestDelay: 10 * time.Millisecond, + subagentDigests: map[string]*subagentDigestState{}, + } + first := bus.InboundMessage{ + Channel: "system", + SenderID: "subagent:subagent-7", + ChatID: "telegram:9527", + Content: "Subagent update\nagent: tester\nrun: subagent-7\nstatus: completed\nreason: final\ntask: 回归测试\nsummary: 所有测试通过", + Metadata: map[string]string{ + "trigger": "subagent", + "agent_id": "tester", + "status": "completed", + "notify_reason": "final", + }, + } + second := bus.InboundMessage{ + Channel: "system", + SenderID: "subagent:subagent-8", + ChatID: "telegram:9527", + Content: "Subagent update\nagent: coder\nrun: subagent-8\nstatus: completed\nreason: final\ntask: 修复登录\nsummary: 接口已联调", + Metadata: map[string]string{ + "trigger": "subagent", + "agent_id": "coder", + "status": "completed", + "notify_reason": "final", + }, + } + if out, err := al.processSystemMessage(context.Background(), first); err != nil || out != "" { + t.Fatalf("unexpected first result out=%q err=%v", out, err) + } + if out, err := al.processSystemMessage(context.Background(), second); err != nil || out != "" { + t.Fatalf("unexpected second result out=%q err=%v", out, err) + } + al.flushDueSubagentDigests(time.Now().Add(20 * time.Millisecond)) + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + outbound, ok := msgBus.SubscribeOutbound(ctx) + if !ok { + t.Fatalf("expected merged outbound digest") + } + if strings.Count(outbound.Content, "\n- ") != 2 { + t.Fatalf("expected two digest lines, got: %q", outbound.Content) + } + if !strings.Contains(outbound.Content, "完成 2") { + t.Fatalf("expected aggregate completion count, got: %q", outbound.Content) + } +} diff --git a/pkg/agent/router_dispatch.go b/pkg/agent/router_dispatch.go index 76d58c7..c755d99 100644 --- a/pkg/agent/router_dispatch.go +++ b/pkg/agent/router_dispatch.go @@ -38,10 +38,11 @@ func (al *AgentLoop) maybeAutoRoute(ctx context.Context, msg bus.InboundMessage) waitCtx, cancel := context.WithTimeout(ctx, time.Duration(waitTimeout)*time.Second) defer cancel() task, err := al.subagentRouter.DispatchTask(waitCtx, tools.RouterDispatchRequest{ - Task: taskText, - AgentID: agentID, - OriginChannel: msg.Channel, - OriginChatID: msg.ChatID, + Task: taskText, + AgentID: agentID, + NotifyMainPolicy: "internal_only", + OriginChannel: msg.Channel, + OriginChatID: msg.ChatID, }) if err != nil { return "", true, err diff --git a/pkg/agent/runtime_admin.go b/pkg/agent/runtime_admin.go index 5b59e89..dee2387 100644 --- a/pkg/agent/runtime_admin.go +++ b/pkg/agent/runtime_admin.go @@ -76,20 +76,21 @@ func (al *AgentLoop) HandleSubagentRuntime(ctx context.Context, action string, a return nil, fmt.Errorf("task is required") } task, err := router.DispatchTask(ctx, tools.RouterDispatchRequest{ - Task: taskInput, - Label: runtimeStringArg(args, "label"), - Role: runtimeStringArg(args, "role"), - AgentID: runtimeStringArg(args, "agent_id"), - ThreadID: runtimeStringArg(args, "thread_id"), - CorrelationID: runtimeStringArg(args, "correlation_id"), - ParentRunID: runtimeStringArg(args, "parent_run_id"), - OriginChannel: fallbackString(runtimeStringArg(args, "channel"), "webui"), - OriginChatID: fallbackString(runtimeStringArg(args, "chat_id"), "webui"), - MaxRetries: runtimeIntArg(args, "max_retries", 0), - RetryBackoff: runtimeIntArg(args, "retry_backoff_ms", 0), - TimeoutSec: runtimeIntArg(args, "timeout_sec", 0), - MaxTaskChars: runtimeIntArg(args, "max_task_chars", 0), - MaxResultChars: runtimeIntArg(args, "max_result_chars", 0), + Task: taskInput, + Label: runtimeStringArg(args, "label"), + Role: runtimeStringArg(args, "role"), + AgentID: runtimeStringArg(args, "agent_id"), + NotifyMainPolicy: "internal_only", + ThreadID: runtimeStringArg(args, "thread_id"), + CorrelationID: runtimeStringArg(args, "correlation_id"), + ParentRunID: runtimeStringArg(args, "parent_run_id"), + OriginChannel: fallbackString(runtimeStringArg(args, "channel"), "webui"), + OriginChatID: fallbackString(runtimeStringArg(args, "chat_id"), "webui"), + MaxRetries: runtimeIntArg(args, "max_retries", 0), + RetryBackoff: runtimeIntArg(args, "retry_backoff_ms", 0), + TimeoutSec: runtimeIntArg(args, "timeout_sec", 0), + MaxTaskChars: runtimeIntArg(args, "max_task_chars", 0), + MaxResultChars: runtimeIntArg(args, "max_result_chars", 0), }) if err != nil { return nil, err diff --git a/pkg/agent/session_planner.go b/pkg/agent/session_planner.go index 91d4f29..21a1f39 100644 --- a/pkg/agent/session_planner.go +++ b/pkg/agent/session_planner.go @@ -12,9 +12,11 @@ import ( "regexp" "strings" "sync" + "time" "clawgo/pkg/bus" "clawgo/pkg/ekg" + "clawgo/pkg/providers" "clawgo/pkg/scheduling" ) @@ -35,18 +37,23 @@ type plannedTaskResult struct { var reLeadingNumber = regexp.MustCompile(`^\d+[\.)、]\s*`) func (al *AgentLoop) processPlannedMessage(ctx context.Context, msg bus.InboundMessage) (string, error) { - tasks := al.planSessionTasks(msg) + tasks := al.planSessionTasks(ctx, msg) if len(tasks) <= 1 { return al.processMessage(ctx, msg) } return al.runPlannedTasks(ctx, msg, tasks) } -func (al *AgentLoop) planSessionTasks(msg bus.InboundMessage) []plannedTask { +func (al *AgentLoop) planSessionTasks(ctx context.Context, msg bus.InboundMessage) []plannedTask { base := strings.TrimSpace(msg.Content) if base == "" { return nil } + if msg.Metadata != nil { + if planningDisabled(msg.Metadata["disable_planning"]) { + return []plannedTask{{Index: 1, Content: base, ResourceKeys: scheduling.DeriveResourceKeys(base)}} + } + } if msg.Channel == "system" || msg.Channel == "internal" { return []plannedTask{{Index: 1, Content: base, ResourceKeys: scheduling.DeriveResourceKeys(base)}} } @@ -63,6 +70,9 @@ func (al *AgentLoop) planSessionTasks(msg bus.InboundMessage) []plannedTask { if len(segments) <= 1 { return []plannedTask{{Index: 1, Content: base, ResourceKeys: scheduling.DeriveResourceKeys(base)}} } + if refined, ok := al.inferPlannedSegments(ctx, base, segments); ok { + segments = refined + } out := make([]plannedTask, 0, len(segments)) for i, seg := range segments { @@ -86,6 +96,125 @@ func (al *AgentLoop) planSessionTasks(msg bus.InboundMessage) []plannedTask { return out } +type plannerDecision struct { + ShouldSplit bool `json:"should_split"` + Tasks []string `json:"tasks"` +} + +func (al *AgentLoop) inferPlannedSegments(ctx context.Context, content string, candidates []string) ([]string, bool) { + if al == nil || al.provider == nil { + return nil, false + } + content = strings.TrimSpace(content) + if content == "" || len(candidates) <= 1 { + return nil, false + } + plannerCtx := ctx + if plannerCtx == nil { + plannerCtx = context.Background() + } + var cancel context.CancelFunc + plannerCtx, cancel = context.WithTimeout(plannerCtx, 8*time.Second) + defer cancel() + + previewCandidates := candidates + if len(previewCandidates) > 12 { + previewCandidates = previewCandidates[:12] + } + var candidateList strings.Builder + for i, item := range previewCandidates { + candidateList.WriteString(fmt.Sprintf("%d. %s\n", i+1, strings.TrimSpace(item))) + } + + resp, err := al.provider.Chat(plannerCtx, []providers.Message{ + { + Role: "system", + Content: "Decide whether the request should stay as one task or be split into a small number of high-level task groups. " + + "Default to no split. Only split when the request contains multiple independent deliverables, roles, or workstreams. " + + "Never split into fine-grained execution steps. Merge related steps. Return strict JSON only: " + + `{"should_split":true|false,"tasks":["..."]}` + + " If should_split is false, tasks must be empty. If true, tasks must contain 2 to 8 concise high-level task groups.", + }, + { + Role: "user", + Content: fmt.Sprintf("Original request:\n%s\n\nRule-based candidate segments (%d shown of %d):\n%s", + content, + len(previewCandidates), + len(candidates), + strings.TrimSpace(candidateList.String()), + ), + }, + }, nil, al.provider.GetDefaultModel(), map[string]interface{}{ + "max_tokens": 256, + }) + if err != nil || resp == nil { + return nil, false + } + decision, ok := parsePlannerDecision(resp.Content) + if !ok { + return nil, false + } + if !decision.ShouldSplit { + return []string{content}, true + } + tasks := sanitizePlannerTasks(decision.Tasks) + if len(tasks) < 2 || len(tasks) > 8 { + return []string{content}, true + } + return tasks, true +} + +func parsePlannerDecision(raw string) (plannerDecision, bool) { + raw = strings.TrimSpace(raw) + if raw == "" { + return plannerDecision{}, false + } + if fenced := extractJSONObject(raw); fenced != "" { + raw = fenced + } + var out plannerDecision + if err := json.Unmarshal([]byte(raw), &out); err != nil { + return plannerDecision{}, false + } + return out, true +} + +func extractJSONObject(raw string) string { + start := strings.Index(raw, "{") + end := strings.LastIndex(raw, "}") + if start < 0 || end <= start { + return "" + } + return strings.TrimSpace(raw[start : end+1]) +} + +func sanitizePlannerTasks(items []string) []string { + out := make([]string, 0, len(items)) + seen := map[string]struct{}{} + for _, item := range items { + t := strings.TrimSpace(reLeadingNumber.ReplaceAllString(strings.TrimSpace(item), "")) + if t == "" { + continue + } + key := strings.ToLower(t) + if _, ok := seen[key]; ok { + continue + } + seen[key] = struct{}{} + out = append(out, t) + } + return out +} + +func planningDisabled(v string) bool { + switch strings.ToLower(strings.TrimSpace(v)) { + case "1", "true", "yes", "on": + return true + default: + return false + } +} + func splitPlannedSegments(content string) []string { lines := strings.Split(content, "\n") bullet := make([]string, 0, len(lines)) diff --git a/pkg/agent/session_planner_split_test.go b/pkg/agent/session_planner_split_test.go index 30382d0..9e0606a 100644 --- a/pkg/agent/session_planner_split_test.go +++ b/pkg/agent/session_planner_split_test.go @@ -1,6 +1,12 @@ package agent -import "testing" +import ( + "context" + "testing" + + "clawgo/pkg/bus" + "clawgo/pkg/providers" +) func TestSplitPlannedSegmentsDoesNotSplitPlainNewlines(t *testing.T) { t.Parallel() @@ -31,3 +37,75 @@ func TestSplitPlannedSegmentsStillSplitsSemicolons(t *testing.T) { t.Fatalf("expected 2 segments, got %d: %#v", len(got), got) } } + +func TestPlanSessionTasksDisablePlanningMetadataSkipsBulletSplit(t *testing.T) { + t.Parallel() + + al := &AgentLoop{} + msg := bus.InboundMessage{ + Content: "Role Profile Policy:\n1. 先做代码审查\n2. 再执行测试\n\nTask:\n修复登录流程", + Metadata: map[string]string{ + "disable_planning": "true", + }, + } + + got := al.planSessionTasks(context.Background(), msg) + if len(got) != 1 { + t.Fatalf("expected 1 segment, got %d: %#v", len(got), got) + } + if got[0].Content != msg.Content { + t.Fatalf("expected original content to be preserved, got: %#v", got[0].Content) + } +} + +func TestPlanSessionTasksUsesAIPlannerDecisionToAvoidOversplitting(t *testing.T) { + t.Parallel() + + al := &AgentLoop{ + provider: plannerStubProvider{content: `{"should_split":false,"tasks":[]}`}, + } + msg := bus.InboundMessage{ + Content: "1. 调研\n2. 写方案\n3. 设计数据库\n4. 写接口\n5. 写前端\n6. 写测试", + } + + got := al.planSessionTasks(context.Background(), msg) + if len(got) != 1 { + t.Fatalf("expected AI planner to keep one task, got %d: %#v", len(got), got) + } + if got[0].Content != msg.Content { + t.Fatalf("expected original content, got %#v", got[0].Content) + } +} + +func TestPlanSessionTasksUsesAIPlannerDecisionToSplitIntoHighLevelGroups(t *testing.T) { + t.Parallel() + + al := &AgentLoop{ + provider: plannerStubProvider{content: `{"should_split":true,"tasks":["产品方案","研发实现","测试验收"]}`}, + } + msg := bus.InboundMessage{ + Content: "1. 调研\n2. 写方案\n3. 设计数据库\n4. 写接口\n5. 写前端\n6. 写测试", + } + + got := al.planSessionTasks(context.Background(), msg) + if len(got) != 3 { + t.Fatalf("expected 3 planned tasks, got %d: %#v", len(got), got) + } + if got[0].Content != "产品方案" || got[1].Content != "研发实现" || got[2].Content != "测试验收" { + t.Fatalf("unexpected planned task contents: %#v", got) + } +} + +type plannerStubProvider struct { + content string + err error +} + +func (p plannerStubProvider) Chat(ctx context.Context, messages []providers.Message, tools []providers.ToolDefinition, model string, options map[string]interface{}) (*providers.LLMResponse, error) { + if p.err != nil { + return nil, p.err + } + return &providers.LLMResponse{Content: p.content}, nil +} + +func (p plannerStubProvider) GetDefaultModel() string { return "stub" } diff --git a/pkg/api/server.go b/pkg/api/server.go index c3a14e5..3fa9bef 100644 --- a/pkg/api/server.go +++ b/pkg/api/server.go @@ -64,6 +64,11 @@ type Server struct { ekgCacheStamp time.Time ekgCacheSize int64 ekgCacheRows []map[string]interface{} + liveRuntimeMu sync.Mutex + liveRuntimeSubs map[chan []byte]struct{} + liveRuntimeOn bool + liveSubagentMu sync.Mutex + liveSubagents map[string]*liveSubagentGroup } var nodesWebsocketUpgrader = websocket.Upgrader{ @@ -79,12 +84,14 @@ func NewServer(host string, port int, token string, mgr *nodes.Manager) *Server port = 7788 } return &Server{ - addr: fmt.Sprintf("%s:%d", addr, port), - token: strings.TrimSpace(token), - mgr: mgr, - nodeConnIDs: map[string]string{}, - nodeSockets: map[string]*nodeSocketConn{}, - artifactStats: map[string]interface{}{}, + addr: fmt.Sprintf("%s:%d", addr, port), + token: strings.TrimSpace(token), + mgr: mgr, + nodeConnIDs: map[string]string{}, + nodeSockets: map[string]*nodeSocketConn{}, + artifactStats: map[string]interface{}{}, + liveRuntimeSubs: map[chan []byte]struct{}{}, + liveSubagents: map[string]*liveSubagentGroup{}, } } @@ -94,6 +101,13 @@ type nodeSocketConn struct { mu sync.Mutex } +type liveSubagentGroup struct { + taskID string + previewTaskID string + subs map[chan []byte]struct{} + stopCh chan struct{} +} + func (c *nodeSocketConn) Send(msg nodes.WireMessage) error { if c == nil || c.conn == nil { return fmt.Errorf("node websocket unavailable") @@ -104,6 +118,168 @@ func (c *nodeSocketConn) Send(msg nodes.WireMessage) error { return c.conn.WriteJSON(msg) } +func publishLiveSnapshot(subs map[chan []byte]struct{}, payload []byte) { + for ch := range subs { + select { + case ch <- payload: + default: + select { + case <-ch: + default: + } + select { + case ch <- payload: + default: + } + } + } +} + +func (s *Server) subscribeRuntimeLive(ctx context.Context) chan []byte { + ch := make(chan []byte, 1) + s.liveRuntimeMu.Lock() + s.liveRuntimeSubs[ch] = struct{}{} + start := !s.liveRuntimeOn + if start { + s.liveRuntimeOn = true + } + s.liveRuntimeMu.Unlock() + if start { + go s.runtimeLiveLoop() + } + go func() { + <-ctx.Done() + s.unsubscribeRuntimeLive(ch) + }() + return ch +} + +func (s *Server) unsubscribeRuntimeLive(ch chan []byte) { + s.liveRuntimeMu.Lock() + delete(s.liveRuntimeSubs, ch) + s.liveRuntimeMu.Unlock() +} + +func (s *Server) runtimeLiveLoop() { + ticker := time.NewTicker(10 * time.Second) + defer ticker.Stop() + for { + if !s.publishRuntimeSnapshot(context.Background()) { + s.liveRuntimeMu.Lock() + if len(s.liveRuntimeSubs) == 0 { + s.liveRuntimeOn = false + s.liveRuntimeMu.Unlock() + return + } + s.liveRuntimeMu.Unlock() + } + <-ticker.C + } +} + +func (s *Server) publishRuntimeSnapshot(ctx context.Context) bool { + if s == nil { + return false + } + payload := map[string]interface{}{ + "ok": true, + "type": "runtime_snapshot", + "snapshot": s.buildWebUIRuntimeSnapshot(ctx), + } + data, err := json.Marshal(payload) + if err != nil { + return false + } + s.liveRuntimeMu.Lock() + defer s.liveRuntimeMu.Unlock() + if len(s.liveRuntimeSubs) == 0 { + return false + } + publishLiveSnapshot(s.liveRuntimeSubs, data) + return true +} + +func buildSubagentLiveKey(taskID, previewTaskID string) string { + return strings.TrimSpace(taskID) + "\x00" + strings.TrimSpace(previewTaskID) +} + +func (s *Server) subscribeSubagentLive(ctx context.Context, taskID, previewTaskID string) chan []byte { + ch := make(chan []byte, 1) + key := buildSubagentLiveKey(taskID, previewTaskID) + s.liveSubagentMu.Lock() + group := s.liveSubagents[key] + if group == nil { + group = &liveSubagentGroup{ + taskID: strings.TrimSpace(taskID), + previewTaskID: strings.TrimSpace(previewTaskID), + subs: map[chan []byte]struct{}{}, + stopCh: make(chan struct{}), + } + s.liveSubagents[key] = group + go s.subagentLiveLoop(key, group) + } + group.subs[ch] = struct{}{} + s.liveSubagentMu.Unlock() + go func() { + <-ctx.Done() + s.unsubscribeSubagentLive(key, ch) + }() + return ch +} + +func (s *Server) unsubscribeSubagentLive(key string, ch chan []byte) { + s.liveSubagentMu.Lock() + group := s.liveSubagents[key] + if group == nil { + s.liveSubagentMu.Unlock() + return + } + delete(group.subs, ch) + if len(group.subs) == 0 { + delete(s.liveSubagents, key) + close(group.stopCh) + } + s.liveSubagentMu.Unlock() +} + +func (s *Server) subagentLiveLoop(key string, group *liveSubagentGroup) { + ticker := time.NewTicker(2 * time.Second) + defer ticker.Stop() + for { + if !s.publishSubagentLiveSnapshot(context.Background(), key, group.taskID, group.previewTaskID) { + return + } + select { + case <-group.stopCh: + return + case <-ticker.C: + } + } +} + +func (s *Server) publishSubagentLiveSnapshot(ctx context.Context, key, taskID, previewTaskID string) bool { + if s == nil { + return false + } + payload := map[string]interface{}{ + "ok": true, + "type": "subagents_live", + "payload": s.buildSubagentsLivePayload(ctx, taskID, previewTaskID), + } + data, err := json.Marshal(payload) + if err != nil { + return false + } + s.liveSubagentMu.Lock() + defer s.liveSubagentMu.Unlock() + group := s.liveSubagents[key] + if group == nil || len(group.subs) == 0 { + return false + } + publishLiveSnapshot(group.subs, data) + return true +} + func (s *Server) SetConfigPath(path string) { s.configPath = strings.TrimSpace(path) } func (s *Server) SetWorkspacePath(path string) { s.workspacePath = strings.TrimSpace(path) } func (s *Server) SetLogFilePath(path string) { s.logFilePath = strings.TrimSpace(path) } @@ -977,28 +1153,23 @@ func (s *Server) handleWebUIRuntime(w http.ResponseWriter, r *http.Request) { defer conn.Close() ctx := r.Context() - ticker := time.NewTicker(10 * time.Second) - defer ticker.Stop() - - sendSnapshot := func() error { - payload := map[string]interface{}{ - "ok": true, - "type": "runtime_snapshot", - "snapshot": s.buildWebUIRuntimeSnapshot(ctx), - } - _ = conn.SetWriteDeadline(time.Now().Add(10 * time.Second)) - return conn.WriteJSON(payload) + sub := s.subscribeRuntimeLive(ctx) + initial := map[string]interface{}{ + "ok": true, + "type": "runtime_snapshot", + "snapshot": s.buildWebUIRuntimeSnapshot(ctx), } - - if err := sendSnapshot(); err != nil { + _ = conn.SetWriteDeadline(time.Now().Add(10 * time.Second)) + if err := conn.WriteJSON(initial); err != nil { return } for { select { case <-ctx.Done(): return - case <-ticker.C: - if err := sendSnapshot(); err != nil { + case payload := <-sub: + _ = conn.SetWriteDeadline(time.Now().Add(10 * time.Second)) + if err := conn.WriteMessage(websocket.TextMessage, payload); err != nil { return } } @@ -4245,28 +4416,23 @@ func (s *Server) handleWebUISubagentsRuntimeLive(w http.ResponseWriter, r *http. ctx := r.Context() taskID := strings.TrimSpace(r.URL.Query().Get("task_id")) previewTaskID := strings.TrimSpace(r.URL.Query().Get("preview_task_id")) - ticker := time.NewTicker(2 * time.Second) - defer ticker.Stop() - - sendSnapshot := func() error { - payload := map[string]interface{}{ - "ok": true, - "type": "subagents_live", - "payload": s.buildSubagentsLivePayload(ctx, taskID, previewTaskID), - } - _ = conn.SetWriteDeadline(time.Now().Add(10 * time.Second)) - return conn.WriteJSON(payload) + sub := s.subscribeSubagentLive(ctx, taskID, previewTaskID) + initial := map[string]interface{}{ + "ok": true, + "type": "subagents_live", + "payload": s.buildSubagentsLivePayload(ctx, taskID, previewTaskID), } - - if err := sendSnapshot(); err != nil { + _ = conn.SetWriteDeadline(time.Now().Add(10 * time.Second)) + if err := conn.WriteJSON(initial); err != nil { return } for { select { case <-ctx.Done(): return - case <-ticker.C: - if err := sendSnapshot(); err != nil { + case payload := <-sub: + _ = conn.SetWriteDeadline(time.Now().Add(10 * time.Second)) + if err := conn.WriteMessage(websocket.TextMessage, payload); err != nil { return } } diff --git a/pkg/tools/subagent.go b/pkg/tools/subagent.go index 914db4b..dcada2a 100644 --- a/pkg/tools/subagent.go +++ b/pkg/tools/subagent.go @@ -53,6 +53,7 @@ type SubagentTask struct { type SubagentManager struct { tasks map[string]*SubagentTask cancelFuncs map[string]context.CancelFunc + waiters map[string]map[chan struct{}]struct{} recoverableTaskIDs []string archiveAfterMinute int64 mu sync.RWMutex @@ -92,6 +93,7 @@ func NewSubagentManager(provider providers.LLMProvider, workspace string, bus *b mgr := &SubagentManager{ tasks: make(map[string]*SubagentTask), cancelFuncs: make(map[string]context.CancelFunc), + waiters: make(map[string]map[chan struct{}]struct{}), archiveAfterMinute: 60, provider: provider, bus: bus, @@ -356,6 +358,7 @@ func (sm *SubagentManager) runTask(ctx context.Context, task *SubagentTask) { CreatedAt: task.Updated, }) sm.persistTaskLocked(task, "completed", task.Result) + sm.notifyTaskWaitersLocked(task.ID) } else { task.Status = "completed" task.Result = applySubagentResultQuota(result, task.MaxResultChars) @@ -373,6 +376,7 @@ func (sm *SubagentManager) runTask(ctx context.Context, task *SubagentTask) { CreatedAt: task.Updated, }) sm.persistTaskLocked(task, "completed", task.Result) + sm.notifyTaskWaitersLocked(task.ID) } sm.mu.Unlock() @@ -790,6 +794,7 @@ func (sm *SubagentManager) KillTask(taskID string) bool { t.WaitingReply = false t.Updated = time.Now().UnixMilli() sm.persistTaskLocked(t, "killed", "") + sm.notifyTaskWaitersLocked(taskID) } return true } @@ -1013,6 +1018,96 @@ func (sm *SubagentManager) persistTaskLocked(task *SubagentTask, eventType, mess }) } +func (sm *SubagentManager) WaitTask(ctx context.Context, taskID string) (*SubagentTask, bool, error) { + if sm == nil { + return nil, false, fmt.Errorf("subagent manager not available") + } + taskID = strings.TrimSpace(taskID) + if taskID == "" { + return nil, false, fmt.Errorf("task id is required") + } + if ctx == nil { + ctx = context.Background() + } + ch := make(chan struct{}, 1) + sm.mu.Lock() + sm.pruneArchivedLocked() + task, ok := sm.tasks[taskID] + if !ok && sm.runStore != nil { + if persisted, found := sm.runStore.Get(taskID); found && persisted != nil { + if strings.TrimSpace(persisted.Status) != "running" { + sm.mu.Unlock() + return persisted, true, nil + } + } + } + if ok && task != nil && strings.TrimSpace(task.Status) != "running" { + cp := cloneSubagentTask(task) + sm.mu.Unlock() + return cp, true, nil + } + waiters := sm.waiters[taskID] + if waiters == nil { + waiters = map[chan struct{}]struct{}{} + sm.waiters[taskID] = waiters + } + waiters[ch] = struct{}{} + sm.mu.Unlock() + + defer sm.removeTaskWaiter(taskID, ch) + for { + select { + case <-ctx.Done(): + return nil, false, ctx.Err() + case <-ch: + sm.mu.Lock() + sm.pruneArchivedLocked() + task, ok := sm.tasks[taskID] + if ok && task != nil && strings.TrimSpace(task.Status) != "running" { + cp := cloneSubagentTask(task) + sm.mu.Unlock() + return cp, true, nil + } + if !ok && sm.runStore != nil { + if persisted, found := sm.runStore.Get(taskID); found && persisted != nil && strings.TrimSpace(persisted.Status) != "running" { + sm.mu.Unlock() + return persisted, true, nil + } + } + sm.mu.Unlock() + } + } +} + +func (sm *SubagentManager) removeTaskWaiter(taskID string, ch chan struct{}) { + sm.mu.Lock() + defer sm.mu.Unlock() + waiters := sm.waiters[taskID] + if len(waiters) == 0 { + delete(sm.waiters, taskID) + return + } + delete(waiters, ch) + if len(waiters) == 0 { + delete(sm.waiters, taskID) + } +} + +func (sm *SubagentManager) notifyTaskWaitersLocked(taskID string) { + waiters := sm.waiters[taskID] + if len(waiters) == 0 { + delete(sm.waiters, taskID) + return + } + for ch := range waiters { + select { + case ch <- struct{}{}: + default: + } + } + delete(sm.waiters, taskID) +} + func (sm *SubagentManager) recordMailboxMessageLocked(task *SubagentTask, msg AgentMessage) { if sm.mailboxStore == nil || task == nil { return diff --git a/pkg/tools/subagent_router.go b/pkg/tools/subagent_router.go index eaaccfb..c17f604 100644 --- a/pkg/tools/subagent_router.go +++ b/pkg/tools/subagent_router.go @@ -8,20 +8,21 @@ import ( ) type RouterDispatchRequest struct { - Task string - Label string - Role string - AgentID string - ThreadID string - CorrelationID string - ParentRunID string - OriginChannel string - OriginChatID string - MaxRetries int - RetryBackoff int - TimeoutSec int - MaxTaskChars int - MaxResultChars int + Task string + Label string + Role string + AgentID string + NotifyMainPolicy string + ThreadID string + CorrelationID string + ParentRunID string + OriginChannel string + OriginChatID string + MaxRetries int + RetryBackoff int + TimeoutSec int + MaxTaskChars int + MaxResultChars int } type RouterReply struct { @@ -46,20 +47,21 @@ func (r *SubagentRouter) DispatchTask(ctx context.Context, req RouterDispatchReq return nil, fmt.Errorf("subagent router is not configured") } task, err := r.manager.SpawnTask(ctx, SubagentSpawnOptions{ - Task: req.Task, - Label: req.Label, - Role: req.Role, - AgentID: req.AgentID, - ThreadID: req.ThreadID, - CorrelationID: req.CorrelationID, - ParentRunID: req.ParentRunID, - OriginChannel: req.OriginChannel, - OriginChatID: req.OriginChatID, - MaxRetries: req.MaxRetries, - RetryBackoff: req.RetryBackoff, - TimeoutSec: req.TimeoutSec, - MaxTaskChars: req.MaxTaskChars, - MaxResultChars: req.MaxResultChars, + Task: req.Task, + Label: req.Label, + Role: req.Role, + AgentID: req.AgentID, + NotifyMainPolicy: req.NotifyMainPolicy, + ThreadID: req.ThreadID, + CorrelationID: req.CorrelationID, + ParentRunID: req.ParentRunID, + OriginChannel: req.OriginChannel, + OriginChatID: req.OriginChatID, + MaxRetries: req.MaxRetries, + RetryBackoff: req.RetryBackoff, + TimeoutSec: req.TimeoutSec, + MaxTaskChars: req.MaxTaskChars, + MaxResultChars: req.MaxResultChars, }) if err != nil { return nil, err @@ -71,33 +73,26 @@ func (r *SubagentRouter) WaitReply(ctx context.Context, taskID string, interval if r == nil || r.manager == nil { return nil, fmt.Errorf("subagent router is not configured") } - if interval <= 0 { - interval = 100 * time.Millisecond - } + _ = interval taskID = strings.TrimSpace(taskID) if taskID == "" { return nil, fmt.Errorf("task id is required") } - ticker := time.NewTicker(interval) - defer ticker.Stop() - for { - task, ok := r.manager.GetTask(taskID) - if ok && task != nil && task.Status != "running" { - return &RouterReply{ - TaskID: task.ID, - ThreadID: task.ThreadID, - CorrelationID: task.CorrelationID, - AgentID: task.AgentID, - Status: task.Status, - Result: strings.TrimSpace(task.Result), - }, nil - } - select { - case <-ctx.Done(): - return nil, ctx.Err() - case <-ticker.C: - } + task, ok, err := r.manager.WaitTask(ctx, taskID) + if err != nil { + return nil, err } + if !ok || task == nil { + return nil, fmt.Errorf("subagent not found") + } + return &RouterReply{ + TaskID: task.ID, + ThreadID: task.ThreadID, + CorrelationID: task.CorrelationID, + AgentID: task.AgentID, + Status: task.Status, + Result: strings.TrimSpace(task.Result), + }, nil } func (r *SubagentRouter) MergeResults(replies []*RouterReply) string { diff --git a/pkg/tools/subagent_router_test.go b/pkg/tools/subagent_router_test.go index 5f7f757..cde693e 100644 --- a/pkg/tools/subagent_router_test.go +++ b/pkg/tools/subagent_router_test.go @@ -47,3 +47,29 @@ func TestSubagentRouterMergeResults(t *testing.T) { t.Fatalf("unexpected merged output: %s", out) } } + +func TestSubagentRouterWaitReplyContextCancel(t *testing.T) { + workspace := t.TempDir() + manager := NewSubagentManager(nil, workspace, nil) + manager.SetRunFunc(func(ctx context.Context, task *SubagentTask) (string, error) { + <-ctx.Done() + return "", ctx.Err() + }) + router := NewSubagentRouter(manager) + + task, err := router.DispatchTask(context.Background(), RouterDispatchRequest{ + Task: "long task", + AgentID: "coder", + OriginChannel: "cli", + OriginChatID: "direct", + }) + if err != nil { + t.Fatalf("dispatch failed: %v", err) + } + + waitCtx, cancel := context.WithTimeout(context.Background(), 20*time.Millisecond) + defer cancel() + if _, err := router.WaitReply(waitCtx, task.ID, 20*time.Millisecond); err == nil { + t.Fatalf("expected context cancellation error") + } +} diff --git a/pkg/tools/task_watchdog.go b/pkg/tools/task_watchdog.go index 6075380..e422c01 100644 --- a/pkg/tools/task_watchdog.go +++ b/pkg/tools/task_watchdog.go @@ -25,6 +25,8 @@ const ( maxWorldCycle = 60 * time.Second ) +const GlobalWatchdogTick = watchdogTick + var ErrCommandNoProgress = errors.New("command no progress across tick rounds") var ErrTaskWatchdogTimeout = errors.New("task watchdog timeout exceeded")