Optimize agent planning and subagent runtime

This commit is contained in:
lpf
2026-03-09 12:33:00 +08:00
parent acf8a22c0a
commit d1abd73e63
11 changed files with 984 additions and 108 deletions

View File

@@ -67,6 +67,24 @@ type AgentLoop struct {
subagentConfigTool *tools.SubagentConfigTool
nodeRouter *nodes.Router
configPath string
subagentDigestMu sync.Mutex
subagentDigestDelay time.Duration
subagentDigests map[string]*subagentDigestState
}
// subagentDigestItem is one subagent status update queued for inclusion in
// a batched digest message to the originating chat.
type subagentDigestItem struct {
	agentID       string // subagent identifier; formatter falls back to "subagent" when empty
	reason        string // lowercased notify reason, e.g. "blocked", "milestone", "final"
	status        string // raw task status, e.g. "completed", "failed"
	taskSummary   string // short description of the task the subagent worked on
	resultSummary string // short description of the outcome or progress
}
// subagentDigestState accumulates pending digest items for one origin chat
// until the debounced deadline (dueAt) passes, at which point the digest is
// flushed as a single outbound message.
type subagentDigestState struct {
	channel string                        // outbound channel the digest is published to
	chatID  string                        // destination chat id on that channel
	items   map[string]subagentDigestItem // pending items keyed by subagentDigestItemKey
	dueAt   time.Time                     // flush deadline; pushed out by each new update
}
func (al *AgentLoop) SetConfigPath(path string) {
@@ -320,7 +338,10 @@ func NewAgentLoop(cfg *config.Config, msgBus *bus.MessageBus, provider providers
subagentRouter: subagentRouter,
subagentConfigTool: subagentConfigTool,
nodeRouter: nodesRouter,
subagentDigestDelay: 5 * time.Second,
subagentDigests: map[string]*subagentDigestState{},
}
go loop.runSubagentDigestTicker()
// Initialize provider fallback chain (primary + proxy_fallbacks).
loop.providerPool = map[string]providers.LLMProvider{}
loop.providerNames = []string{}
@@ -371,7 +392,31 @@ func NewAgentLoop(cfg *config.Config, msgBus *bus.MessageBus, provider providers
sessionKey = fmt.Sprintf("subagent:%s", strings.TrimSpace(task.ID))
}
taskInput := loop.buildSubagentTaskInput(task)
return loop.ProcessDirectWithOptions(ctx, taskInput, sessionKey, task.OriginChannel, task.OriginChatID, task.MemoryNS, task.ToolAllowlist)
ns := normalizeMemoryNamespace(task.MemoryNS)
ctx = withMemoryNamespaceContext(ctx, ns)
ctx = withToolAllowlistContext(ctx, task.ToolAllowlist)
channel := strings.TrimSpace(task.OriginChannel)
if channel == "" {
channel = "cli"
}
chatID := strings.TrimSpace(task.OriginChatID)
if chatID == "" {
chatID = "direct"
}
msg := bus.InboundMessage{
Channel: channel,
SenderID: "subagent",
ChatID: chatID,
Content: taskInput,
SessionKey: sessionKey,
Metadata: map[string]string{
"memory_namespace": ns,
"memory_ns": ns,
"disable_planning": "true",
"trigger": "subagent",
},
}
return loop.processMessage(ctx, msg)
})
return loop
@@ -1349,6 +1394,10 @@ func (al *AgentLoop) processSystemMessage(ctx context.Context, msg bus.InboundMe
"chat_id": msg.ChatID,
})
if al.handleSubagentSystemMessage(msg) {
return "", nil
}
originChannel, originChatID := resolveSystemOrigin(msg.ChatID)
// Use the origin session for context
@@ -1491,6 +1540,36 @@ func (al *AgentLoop) processSystemMessage(ctx context.Context, msg bus.InboundMe
return finalContent, nil
}
// handleSubagentSystemMessage intercepts system messages emitted by
// subagents and queues them into the per-chat digest instead of letting
// them flow through the normal system-message path. It reports whether the
// message was consumed.
func (al *AgentLoop) handleSubagentSystemMessage(msg bus.InboundMessage) bool {
	if !isSubagentSystemMessage(msg) {
		return false
	}
	var item subagentDigestItem
	if md := msg.Metadata; md != nil {
		item.reason = strings.ToLower(strings.TrimSpace(md["notify_reason"]))
		item.agentID = strings.TrimSpace(md["agent_id"])
		item.status = strings.TrimSpace(md["status"])
	}
	if item.agentID == "" {
		// Fall back to the sender id, e.g. "subagent:coder" -> "coder".
		item.agentID = strings.TrimSpace(strings.TrimPrefix(msg.SenderID, "subagent:"))
	}
	// Summaries are never carried in metadata, so derive them from the body.
	item.taskSummary, item.resultSummary = parseSubagentSystemContent(msg.Content)
	al.enqueueSubagentDigest(msg, item)
	return true
}
// truncate returns a truncated version of s with at most maxLen characters.
// If the string is truncated, "..." is appended to indicate truncation.
// If the string fits within maxLen, it is returned unchanged.
@@ -2219,6 +2298,219 @@ func fallbackSubagentNotification(msg bus.InboundMessage) (string, bool) {
return content, true
}
// parseSubagentSystemContent extracts the "task:" and "summary:" lines from
// a subagent notification body (prefix match is case-insensitive). When no
// explicit task line is present, a single-line excerpt of the whole content
// capped at 120 characters is used as the task summary.
func parseSubagentSystemContent(content string) (string, string) {
	content = strings.TrimSpace(content)
	if content == "" {
		return "", ""
	}
	var taskSummary, resultSummary string
	for _, raw := range strings.Split(content, "\n") {
		line := strings.TrimSpace(raw)
		if line == "" {
			continue
		}
		lower := strings.ToLower(line)
		if strings.HasPrefix(lower, "task:") {
			taskSummary = strings.TrimSpace(line[len("task:"):])
		} else if strings.HasPrefix(lower, "summary:") {
			resultSummary = strings.TrimSpace(line[len("summary:"):])
		}
	}
	if taskSummary == "" {
		taskSummary = summarizeSystemNotificationText(content, 120)
	}
	return taskSummary, resultSummary
}
// summarizeSystemNotificationText collapses s into a single line of
// whitespace-normalized text and truncates it to at most max bytes,
// appending "..." when shortened. A max of 0 or less disables truncation.
//
// Fixes over the previous version: max values of 1..3 no longer panic
// (s[:max-3] sliced with a negative index), and truncation backs off to a
// UTF-8 rune boundary so multibyte (e.g. CJK) text is never cut
// mid-character, which previously produced invalid UTF-8 output.
func summarizeSystemNotificationText(s string, max int) string {
	s = strings.TrimSpace(strings.ReplaceAll(s, "\r\n", "\n"))
	s = strings.ReplaceAll(s, "\n", " ")
	s = strings.Join(strings.Fields(s), " ")
	if s == "" {
		return ""
	}
	if max <= 0 || len(s) <= max {
		return s
	}
	// Reserve room for the ellipsis, keeping at least one byte of content.
	cut := max - 3
	if cut < 1 {
		cut = 1
	}
	// Back off past UTF-8 continuation bytes (10xxxxxx) so we never split a
	// multibyte rune. cut < len(s) holds because len(s) > max >= cut.
	for cut > 0 && s[cut]&0xC0 == 0x80 {
		cut--
	}
	if cut == 0 {
		return "..."
	}
	return strings.TrimSpace(s[:cut]) + "..."
}
// enqueueSubagentDigest records a subagent update under the message's
// origin chat and (re)arms the flush deadline. Updates that share the same
// agent/reason/task key overwrite earlier ones, so only the latest state
// per item survives until the digest is flushed.
func (al *AgentLoop) enqueueSubagentDigest(msg bus.InboundMessage, item subagentDigestItem) {
	if al == nil || al.bus == nil {
		return
	}
	channel, chatID := resolveSystemOrigin(msg.ChatID)
	key := channel + "\x00" + chatID
	delay := al.subagentDigestDelay
	if delay <= 0 {
		delay = 5 * time.Second // default debounce window
	}
	al.subagentDigestMu.Lock()
	defer al.subagentDigestMu.Unlock()
	state := al.subagentDigests[key]
	if state == nil {
		state = &subagentDigestState{
			channel: channel,
			chatID:  chatID,
			items:   map[string]subagentDigestItem{},
		}
		al.subagentDigests[key] = state
	}
	state.items[subagentDigestItemKey(item)] = item
	// Every new update pushes the flush deadline out again (debounce).
	state.dueAt = time.Now().Add(delay)
}
// subagentDigestItemKey builds the dedupe key for a digest item from its
// normalized (trimmed, lowercased) agent id, reason, and task summary.
// Later updates with the same key replace earlier ones in the digest.
func subagentDigestItemKey(item subagentDigestItem) string {
	normalize := func(s string) string { return strings.ToLower(strings.TrimSpace(s)) }
	agent := normalize(item.agentID)
	if agent == "" {
		agent = "subagent"
	}
	return strings.Join([]string{agent, normalize(item.reason), normalize(item.taskSummary)}, "\x00")
}
// flushSubagentDigest detaches the pending digest for key and publishes its
// formatted summary to the originating channel/chat. Nothing is sent when
// the state is missing, empty, or formats to a blank message.
func (al *AgentLoop) flushSubagentDigest(key string) {
	if al == nil || al.bus == nil {
		return
	}
	// Remove the state under the lock so concurrent updates start a fresh
	// digest rather than racing with this flush.
	al.subagentDigestMu.Lock()
	pending := al.subagentDigests[key]
	delete(al.subagentDigests, key)
	al.subagentDigestMu.Unlock()
	if pending == nil || len(pending.items) == 0 {
		return
	}
	content := formatSubagentDigestSummary(pending.items)
	if strings.TrimSpace(content) == "" {
		return
	}
	al.bus.PublishOutbound(bus.OutboundMessage{
		Channel: pending.channel,
		ChatID:  pending.chatID,
		Content: content,
	})
}
// runSubagentDigestTicker periodically flushes any due subagent digests.
// It is started once (from NewAgentLoop) and polls at the global watchdog
// tick, defaulting to one second when that tick is unset or non-positive.
// NOTE(review): there is no stop signal — this goroutine runs for the
// lifetime of the process.
func (al *AgentLoop) runSubagentDigestTicker() {
	if al == nil {
		return
	}
	tick := tools.GlobalWatchdogTick
	if tick <= 0 {
		tick = time.Second
	}
	ticker := time.NewTicker(tick)
	defer ticker.Stop()
	for now := range ticker.C {
		al.flushDueSubagentDigests(now)
	}
}
// flushDueSubagentDigests publishes every digest whose deadline has passed
// as of now. Due keys are collected under the lock first and flushed after
// releasing it, because flushing publishes to the message bus.
func (al *AgentLoop) flushDueSubagentDigests(now time.Time) {
	if al == nil || al.bus == nil {
		return
	}
	var due []string
	al.subagentDigestMu.Lock()
	for key, state := range al.subagentDigests {
		if state != nil && !state.dueAt.IsZero() && !now.Before(state.dueAt) {
			due = append(due, key)
		}
	}
	al.subagentDigestMu.Unlock()
	for _, key := range due {
		al.flushSubagentDigest(key)
	}
}
// formatSubagentDigestSummary renders the queued digest items as a single
// Chinese-language status message: a header line with aggregate counts
// (completed / blocked / failed / progressing), then one bullet per item,
// sorted by agent id and then task summary for deterministic output.
// NOTE(review): several WriteString calls join pieces with "" — separator
// punctuation may have been lost in transcription; confirm the intended
// delimiters against the rendered output.
func formatSubagentDigestSummary(items map[string]subagentDigestItem) string {
	if len(items) == 0 {
		return ""
	}
	list := make([]subagentDigestItem, 0, len(items))
	// Aggregate counters for the header line.
	completed := 0
	blocked := 0
	milestone := 0
	failed := 0
	for _, item := range items {
		list = append(list, item)
		switch {
		case strings.EqualFold(strings.TrimSpace(item.reason), "blocked"):
			blocked++
		case strings.EqualFold(strings.TrimSpace(item.reason), "milestone"):
			milestone++
		case strings.EqualFold(strings.TrimSpace(item.status), "failed"):
			failed++
		default:
			// Any other reason/status combination counts as completed.
			completed++
		}
	}
	// Stable ordering: agent id first, task summary as tie-breaker.
	sort.Slice(list, func(i, j int) bool {
		left := strings.TrimSpace(list[i].agentID)
		right := strings.TrimSpace(list[j].agentID)
		if left == right {
			return strings.TrimSpace(list[i].taskSummary) < strings.TrimSpace(list[j].taskSummary)
		}
		return left < right
	})
	var sb strings.Builder
	sb.WriteString("阶段总结")
	stats := make([]string, 0, 4)
	if completed > 0 {
		stats = append(stats, fmt.Sprintf("完成 %d", completed))
	}
	if blocked > 0 {
		stats = append(stats, fmt.Sprintf("受阻 %d", blocked))
	}
	if failed > 0 {
		stats = append(stats, fmt.Sprintf("失败 %d", failed))
	}
	if milestone > 0 {
		stats = append(stats, fmt.Sprintf("进展 %d", milestone))
	}
	if len(stats) > 0 {
		sb.WriteString("" + strings.Join(stats, "") + "")
	}
	sb.WriteString("\n")
	for _, item := range list {
		agentLabel := strings.TrimSpace(item.agentID)
		if agentLabel == "" {
			agentLabel = "subagent"
		}
		// Map reason/status to a human-readable state label; same priority
		// order as the counters above.
		statusText := "已完成"
		switch {
		case strings.EqualFold(strings.TrimSpace(item.reason), "blocked"):
			statusText = "受阻"
		case strings.EqualFold(strings.TrimSpace(item.reason), "milestone"):
			statusText = "有进展"
		case strings.EqualFold(strings.TrimSpace(item.status), "failed"):
			statusText = "失败"
		}
		sb.WriteString("- " + agentLabel + "" + statusText)
		if task := strings.TrimSpace(item.taskSummary); task != "" {
			sb.WriteString(",任务:" + task)
		}
		if summary := strings.TrimSpace(item.resultSummary); summary != "" {
			// Blocked items show a cause, milestones show progress, all
			// other states show a result summary.
			label := "摘要"
			if statusText == "受阻" {
				label = "原因"
			} else if statusText == "有进展" {
				label = "进度"
			}
			sb.WriteString("" + label + "" + summary)
		}
		sb.WriteString("\n")
	}
	return strings.TrimSpace(sb.String())
}
func shouldFlushTelegramStreamSnapshot(s string) bool {
s = strings.TrimRight(s, " \t")
if s == "" {

View File

@@ -1,8 +1,10 @@
package agent
import (
"context"
"strings"
"testing"
"time"
"clawgo/pkg/bus"
)
@@ -66,3 +68,92 @@ func TestPrepareOutboundSubagentNoReplyFallbackWithMissingOrigin(t *testing.T) {
t.Fatalf("unexpected fallback content: %q", outbound.Content)
}
}
// TestProcessSystemMessageSubagentBlockedQueuedIntoDigest verifies that a
// blocked subagent system message produces no immediate reply and is
// instead delivered later as an aggregated digest on the outbound bus.
func TestProcessSystemMessageSubagentBlockedQueuedIntoDigest(t *testing.T) {
	msgBus := bus.NewMessageBus()
	al := &AgentLoop{
		bus:                 msgBus,
		subagentDigestDelay: 10 * time.Millisecond,
		subagentDigests:     map[string]*subagentDigestState{},
	}
	out, err := al.processSystemMessage(context.Background(), bus.InboundMessage{
		Channel:  "system",
		SenderID: "subagent:subagent-3",
		ChatID:   "telegram:9527",
		Content:  "Subagent update\nagent: coder\nrun: subagent-3\nstatus: blocked\nreason: blocked\ntask: 修复登录\nsummary: rate limit",
		Metadata: map[string]string{
			"trigger":       "subagent",
			"agent_id":      "coder",
			"status":        "failed",
			"notify_reason": "blocked",
		},
	})
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if out != "" {
		t.Fatalf("expected queued digest with no immediate output, got %q", out)
	}
	// Simulate the debounce window elapsing, then expect one outbound digest.
	al.flushDueSubagentDigests(time.Now().Add(20 * time.Millisecond))
	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
	defer cancel()
	outbound, ok := msgBus.SubscribeOutbound(ctx)
	if !ok {
		t.Fatalf("expected outbound digest")
	}
	if !strings.Contains(outbound.Content, "阶段总结") || !strings.Contains(outbound.Content, "coder") || !strings.Contains(outbound.Content, "受阻") {
		t.Fatalf("unexpected digest content: %q", outbound.Content)
	}
}
// TestProcessSystemMessageSubagentDigestMergesMultipleUpdates verifies that
// two subagent completions for the same origin chat are merged into a
// single outbound digest with one bullet per agent and an aggregate count.
func TestProcessSystemMessageSubagentDigestMergesMultipleUpdates(t *testing.T) {
	msgBus := bus.NewMessageBus()
	al := &AgentLoop{
		bus:                 msgBus,
		subagentDigestDelay: 10 * time.Millisecond,
		subagentDigests:     map[string]*subagentDigestState{},
	}
	first := bus.InboundMessage{
		Channel:  "system",
		SenderID: "subagent:subagent-7",
		ChatID:   "telegram:9527",
		Content:  "Subagent update\nagent: tester\nrun: subagent-7\nstatus: completed\nreason: final\ntask: 回归测试\nsummary: 所有测试通过",
		Metadata: map[string]string{
			"trigger":       "subagent",
			"agent_id":      "tester",
			"status":        "completed",
			"notify_reason": "final",
		},
	}
	second := bus.InboundMessage{
		Channel:  "system",
		SenderID: "subagent:subagent-8",
		ChatID:   "telegram:9527",
		Content:  "Subagent update\nagent: coder\nrun: subagent-8\nstatus: completed\nreason: final\ntask: 修复登录\nsummary: 接口已联调",
		Metadata: map[string]string{
			"trigger":       "subagent",
			"agent_id":      "coder",
			"status":        "completed",
			"notify_reason": "final",
		},
	}
	if out, err := al.processSystemMessage(context.Background(), first); err != nil || out != "" {
		t.Fatalf("unexpected first result out=%q err=%v", out, err)
	}
	if out, err := al.processSystemMessage(context.Background(), second); err != nil || out != "" {
		t.Fatalf("unexpected second result out=%q err=%v", out, err)
	}
	// Force the debounce window to elapse so the merged digest is published.
	al.flushDueSubagentDigests(time.Now().Add(20 * time.Millisecond))
	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
	defer cancel()
	outbound, ok := msgBus.SubscribeOutbound(ctx)
	if !ok {
		t.Fatalf("expected merged outbound digest")
	}
	if strings.Count(outbound.Content, "\n- ") != 2 {
		t.Fatalf("expected two digest lines, got: %q", outbound.Content)
	}
	if !strings.Contains(outbound.Content, "完成 2") {
		t.Fatalf("expected aggregate completion count, got: %q", outbound.Content)
	}
}

View File

@@ -38,10 +38,11 @@ func (al *AgentLoop) maybeAutoRoute(ctx context.Context, msg bus.InboundMessage)
waitCtx, cancel := context.WithTimeout(ctx, time.Duration(waitTimeout)*time.Second)
defer cancel()
task, err := al.subagentRouter.DispatchTask(waitCtx, tools.RouterDispatchRequest{
Task: taskText,
AgentID: agentID,
OriginChannel: msg.Channel,
OriginChatID: msg.ChatID,
Task: taskText,
AgentID: agentID,
NotifyMainPolicy: "internal_only",
OriginChannel: msg.Channel,
OriginChatID: msg.ChatID,
})
if err != nil {
return "", true, err

View File

@@ -76,20 +76,21 @@ func (al *AgentLoop) HandleSubagentRuntime(ctx context.Context, action string, a
return nil, fmt.Errorf("task is required")
}
task, err := router.DispatchTask(ctx, tools.RouterDispatchRequest{
Task: taskInput,
Label: runtimeStringArg(args, "label"),
Role: runtimeStringArg(args, "role"),
AgentID: runtimeStringArg(args, "agent_id"),
ThreadID: runtimeStringArg(args, "thread_id"),
CorrelationID: runtimeStringArg(args, "correlation_id"),
ParentRunID: runtimeStringArg(args, "parent_run_id"),
OriginChannel: fallbackString(runtimeStringArg(args, "channel"), "webui"),
OriginChatID: fallbackString(runtimeStringArg(args, "chat_id"), "webui"),
MaxRetries: runtimeIntArg(args, "max_retries", 0),
RetryBackoff: runtimeIntArg(args, "retry_backoff_ms", 0),
TimeoutSec: runtimeIntArg(args, "timeout_sec", 0),
MaxTaskChars: runtimeIntArg(args, "max_task_chars", 0),
MaxResultChars: runtimeIntArg(args, "max_result_chars", 0),
Task: taskInput,
Label: runtimeStringArg(args, "label"),
Role: runtimeStringArg(args, "role"),
AgentID: runtimeStringArg(args, "agent_id"),
NotifyMainPolicy: "internal_only",
ThreadID: runtimeStringArg(args, "thread_id"),
CorrelationID: runtimeStringArg(args, "correlation_id"),
ParentRunID: runtimeStringArg(args, "parent_run_id"),
OriginChannel: fallbackString(runtimeStringArg(args, "channel"), "webui"),
OriginChatID: fallbackString(runtimeStringArg(args, "chat_id"), "webui"),
MaxRetries: runtimeIntArg(args, "max_retries", 0),
RetryBackoff: runtimeIntArg(args, "retry_backoff_ms", 0),
TimeoutSec: runtimeIntArg(args, "timeout_sec", 0),
MaxTaskChars: runtimeIntArg(args, "max_task_chars", 0),
MaxResultChars: runtimeIntArg(args, "max_result_chars", 0),
})
if err != nil {
return nil, err

View File

@@ -12,9 +12,11 @@ import (
"regexp"
"strings"
"sync"
"time"
"clawgo/pkg/bus"
"clawgo/pkg/ekg"
"clawgo/pkg/providers"
"clawgo/pkg/scheduling"
)
@@ -35,18 +37,23 @@ type plannedTaskResult struct {
var reLeadingNumber = regexp.MustCompile(`^\d+[\.)、]\s*`)
func (al *AgentLoop) processPlannedMessage(ctx context.Context, msg bus.InboundMessage) (string, error) {
tasks := al.planSessionTasks(msg)
tasks := al.planSessionTasks(ctx, msg)
if len(tasks) <= 1 {
return al.processMessage(ctx, msg)
}
return al.runPlannedTasks(ctx, msg, tasks)
}
func (al *AgentLoop) planSessionTasks(msg bus.InboundMessage) []plannedTask {
func (al *AgentLoop) planSessionTasks(ctx context.Context, msg bus.InboundMessage) []plannedTask {
base := strings.TrimSpace(msg.Content)
if base == "" {
return nil
}
if msg.Metadata != nil {
if planningDisabled(msg.Metadata["disable_planning"]) {
return []plannedTask{{Index: 1, Content: base, ResourceKeys: scheduling.DeriveResourceKeys(base)}}
}
}
if msg.Channel == "system" || msg.Channel == "internal" {
return []plannedTask{{Index: 1, Content: base, ResourceKeys: scheduling.DeriveResourceKeys(base)}}
}
@@ -63,6 +70,9 @@ func (al *AgentLoop) planSessionTasks(msg bus.InboundMessage) []plannedTask {
if len(segments) <= 1 {
return []plannedTask{{Index: 1, Content: base, ResourceKeys: scheduling.DeriveResourceKeys(base)}}
}
if refined, ok := al.inferPlannedSegments(ctx, base, segments); ok {
segments = refined
}
out := make([]plannedTask, 0, len(segments))
for i, seg := range segments {
@@ -86,6 +96,125 @@ func (al *AgentLoop) planSessionTasks(msg bus.InboundMessage) []plannedTask {
return out
}
// plannerDecision is the strict-JSON verdict expected from the planning
// model: whether to split the request and, if so, into which task groups.
type plannerDecision struct {
	ShouldSplit bool     `json:"should_split"` // false means keep the request as one task
	Tasks       []string `json:"tasks"`        // 2-8 high-level groups when ShouldSplit is true
}
// inferPlannedSegments asks the LLM whether the rule-based candidate
// segmentation should be kept as a single task or consolidated into 2-8
// high-level task groups. It returns (segments, true) when the model gave a
// usable verdict, and (nil, false) on any failure so the caller keeps the
// rule-based segments. The provider call is bounded to 8 seconds.
func (al *AgentLoop) inferPlannedSegments(ctx context.Context, content string, candidates []string) ([]string, bool) {
	if al == nil || al.provider == nil {
		return nil, false
	}
	content = strings.TrimSpace(content)
	if content == "" || len(candidates) <= 1 {
		return nil, false
	}
	plannerCtx := ctx
	if plannerCtx == nil {
		plannerCtx = context.Background()
	}
	var cancel context.CancelFunc
	plannerCtx, cancel = context.WithTimeout(plannerCtx, 8*time.Second)
	defer cancel()
	// Cap prompt size: show at most 12 candidate segments to the model.
	previewCandidates := candidates
	if len(previewCandidates) > 12 {
		previewCandidates = previewCandidates[:12]
	}
	var candidateList strings.Builder
	for i, item := range previewCandidates {
		candidateList.WriteString(fmt.Sprintf("%d. %s\n", i+1, strings.TrimSpace(item)))
	}
	resp, err := al.provider.Chat(plannerCtx, []providers.Message{
		{
			Role: "system",
			Content: "Decide whether the request should stay as one task or be split into a small number of high-level task groups. " +
				"Default to no split. Only split when the request contains multiple independent deliverables, roles, or workstreams. " +
				"Never split into fine-grained execution steps. Merge related steps. Return strict JSON only: " +
				`{"should_split":true|false,"tasks":["..."]}` +
				" If should_split is false, tasks must be empty. If true, tasks must contain 2 to 8 concise high-level task groups.",
		},
		{
			Role: "user",
			Content: fmt.Sprintf("Original request:\n%s\n\nRule-based candidate segments (%d shown of %d):\n%s",
				content,
				len(previewCandidates),
				len(candidates),
				strings.TrimSpace(candidateList.String()),
			),
		},
	}, nil, al.provider.GetDefaultModel(), map[string]interface{}{
		"max_tokens": 256,
	})
	if err != nil || resp == nil {
		return nil, false
	}
	decision, ok := parsePlannerDecision(resp.Content)
	if !ok {
		return nil, false
	}
	if !decision.ShouldSplit {
		// Model says keep as one task: collapse to the original content.
		return []string{content}, true
	}
	tasks := sanitizePlannerTasks(decision.Tasks)
	if len(tasks) < 2 || len(tasks) > 8 {
		// An out-of-range split is treated the same as "do not split".
		return []string{content}, true
	}
	return tasks, true
}
// parsePlannerDecision decodes the model's reply into a plannerDecision.
// Chatty or fenced responses are tolerated: when a brace-delimited object
// can be located inside raw, only that object is decoded. The boolean is
// false when nothing parseable was found.
func parsePlannerDecision(raw string) (plannerDecision, bool) {
	trimmed := strings.TrimSpace(raw)
	if trimmed == "" {
		return plannerDecision{}, false
	}
	if obj := extractJSONObject(trimmed); obj != "" {
		trimmed = obj
	}
	var decision plannerDecision
	if json.Unmarshal([]byte(trimmed), &decision) != nil {
		return plannerDecision{}, false
	}
	return decision, true
}
// extractJSONObject returns the widest substring of raw spanning the first
// '{' through the last '}', trimmed of surrounding whitespace, or "" when
// no such span exists. It does not validate that the span is actual JSON.
func extractJSONObject(raw string) string {
	first := strings.Index(raw, "{")
	if first < 0 {
		return ""
	}
	last := strings.LastIndex(raw, "}")
	if last <= first {
		return ""
	}
	return strings.TrimSpace(raw[first : last+1])
}
// sanitizePlannerTasks normalizes planner-produced task strings: leading
// list numbering is stripped, blanks are dropped, and duplicates (compared
// case-insensitively) are removed while preserving first-seen order.
func sanitizePlannerTasks(items []string) []string {
	result := make([]string, 0, len(items))
	seen := make(map[string]struct{}, len(items))
	for _, raw := range items {
		task := strings.TrimSpace(reLeadingNumber.ReplaceAllString(strings.TrimSpace(raw), ""))
		if task == "" {
			continue
		}
		dedupeKey := strings.ToLower(task)
		if _, dup := seen[dedupeKey]; dup {
			continue
		}
		seen[dedupeKey] = struct{}{}
		result = append(result, task)
	}
	return result
}
// planningDisabled reports whether v is a truthy flag value ("1", "true",
// "yes", or "on"), compared case-insensitively and ignoring surrounding
// whitespace.
func planningDisabled(v string) bool {
	normalized := strings.ToLower(strings.TrimSpace(v))
	return normalized == "1" || normalized == "true" || normalized == "yes" || normalized == "on"
}
func splitPlannedSegments(content string) []string {
lines := strings.Split(content, "\n")
bullet := make([]string, 0, len(lines))

View File

@@ -1,6 +1,12 @@
package agent
import "testing"
import (
"context"
"testing"
"clawgo/pkg/bus"
"clawgo/pkg/providers"
)
func TestSplitPlannedSegmentsDoesNotSplitPlainNewlines(t *testing.T) {
t.Parallel()
@@ -31,3 +37,75 @@ func TestSplitPlannedSegmentsStillSplitsSemicolons(t *testing.T) {
t.Fatalf("expected 2 segments, got %d: %#v", len(got), got)
}
}
// TestPlanSessionTasksDisablePlanningMetadataSkipsBulletSplit verifies that
// the disable_planning metadata flag keeps a bullet-laden message as a
// single planned task with its content untouched.
func TestPlanSessionTasksDisablePlanningMetadataSkipsBulletSplit(t *testing.T) {
	t.Parallel()
	al := &AgentLoop{}
	msg := bus.InboundMessage{
		Content: "Role Profile Policy:\n1. 先做代码审查\n2. 再执行测试\n\nTask:\n修复登录流程",
		Metadata: map[string]string{
			"disable_planning": "true",
		},
	}
	got := al.planSessionTasks(context.Background(), msg)
	if len(got) != 1 {
		t.Fatalf("expected 1 segment, got %d: %#v", len(got), got)
	}
	if got[0].Content != msg.Content {
		t.Fatalf("expected original content to be preserved, got: %#v", got[0].Content)
	}
}
// TestPlanSessionTasksUsesAIPlannerDecisionToAvoidOversplitting verifies
// that a planner verdict of should_split=false collapses a numbered list
// back into one task carrying the original content.
func TestPlanSessionTasksUsesAIPlannerDecisionToAvoidOversplitting(t *testing.T) {
	t.Parallel()
	al := &AgentLoop{
		provider: plannerStubProvider{content: `{"should_split":false,"tasks":[]}`},
	}
	msg := bus.InboundMessage{
		Content: "1. 调研\n2. 写方案\n3. 设计数据库\n4. 写接口\n5. 写前端\n6. 写测试",
	}
	got := al.planSessionTasks(context.Background(), msg)
	if len(got) != 1 {
		t.Fatalf("expected AI planner to keep one task, got %d: %#v", len(got), got)
	}
	if got[0].Content != msg.Content {
		t.Fatalf("expected original content, got %#v", got[0].Content)
	}
}
// TestPlanSessionTasksUsesAIPlannerDecisionToSplitIntoHighLevelGroups
// verifies that a should_split=true verdict replaces the rule-based
// segments with the planner's high-level task groups, in order.
func TestPlanSessionTasksUsesAIPlannerDecisionToSplitIntoHighLevelGroups(t *testing.T) {
	t.Parallel()
	al := &AgentLoop{
		provider: plannerStubProvider{content: `{"should_split":true,"tasks":["产品方案","研发实现","测试验收"]}`},
	}
	msg := bus.InboundMessage{
		Content: "1. 调研\n2. 写方案\n3. 设计数据库\n4. 写接口\n5. 写前端\n6. 写测试",
	}
	got := al.planSessionTasks(context.Background(), msg)
	if len(got) != 3 {
		t.Fatalf("expected 3 planned tasks, got %d: %#v", len(got), got)
	}
	if got[0].Content != "产品方案" || got[1].Content != "研发实现" || got[2].Content != "测试验收" {
		t.Fatalf("unexpected planned task contents: %#v", got)
	}
}
// plannerStubProvider is a minimal LLM-provider stand-in for planner tests:
// Chat returns a canned response (or error) and ignores its inputs.
type plannerStubProvider struct {
	content string // canned response body returned by Chat
	err     error  // when non-nil, Chat fails with this error instead
}

// Chat returns the canned content, ignoring all arguments.
func (p plannerStubProvider) Chat(ctx context.Context, messages []providers.Message, tools []providers.ToolDefinition, model string, options map[string]interface{}) (*providers.LLMResponse, error) {
	if p.err != nil {
		return nil, p.err
	}
	return &providers.LLMResponse{Content: p.content}, nil
}

// GetDefaultModel returns a fixed stub model name.
func (p plannerStubProvider) GetDefaultModel() string { return "stub" }

View File

@@ -64,6 +64,11 @@ type Server struct {
ekgCacheStamp time.Time
ekgCacheSize int64
ekgCacheRows []map[string]interface{}
liveRuntimeMu sync.Mutex
liveRuntimeSubs map[chan []byte]struct{}
liveRuntimeOn bool
liveSubagentMu sync.Mutex
liveSubagents map[string]*liveSubagentGroup
}
var nodesWebsocketUpgrader = websocket.Upgrader{
@@ -79,12 +84,14 @@ func NewServer(host string, port int, token string, mgr *nodes.Manager) *Server
port = 7788
}
return &Server{
addr: fmt.Sprintf("%s:%d", addr, port),
token: strings.TrimSpace(token),
mgr: mgr,
nodeConnIDs: map[string]string{},
nodeSockets: map[string]*nodeSocketConn{},
artifactStats: map[string]interface{}{},
addr: fmt.Sprintf("%s:%d", addr, port),
token: strings.TrimSpace(token),
mgr: mgr,
nodeConnIDs: map[string]string{},
nodeSockets: map[string]*nodeSocketConn{},
artifactStats: map[string]interface{}{},
liveRuntimeSubs: map[chan []byte]struct{}{},
liveSubagents: map[string]*liveSubagentGroup{},
}
}
@@ -94,6 +101,13 @@ type nodeSocketConn struct {
mu sync.Mutex
}
// liveSubagentGroup is one websocket fan-out group for subagent live
// snapshots, keyed by the (taskID, previewTaskID) pair. Its publish loop
// runs while subs is non-empty and is stopped by closing stopCh.
type liveSubagentGroup struct {
	taskID        string                   // task filter forwarded to the snapshot builder
	previewTaskID string                   // preview-task filter forwarded to the snapshot builder
	subs          map[chan []byte]struct{} // subscriber channels (each with buffer 1)
	stopCh        chan struct{}            // closed when the last subscriber leaves
}
func (c *nodeSocketConn) Send(msg nodes.WireMessage) error {
if c == nil || c.conn == nil {
return fmt.Errorf("node websocket unavailable")
@@ -104,6 +118,168 @@ func (c *nodeSocketConn) Send(msg nodes.WireMessage) error {
return c.conn.WriteJSON(msg)
}
// publishLiveSnapshot delivers payload to every subscriber channel without
// ever blocking. When a subscriber's buffer is full, one stale pending
// snapshot is evicted so the newest payload can take its place; if the
// channel is still full after that, the payload is dropped for that
// subscriber.
func publishLiveSnapshot(subs map[chan []byte]struct{}, payload []byte) {
	for sub := range subs {
		select {
		case sub <- payload:
			continue
		default:
		}
		// Buffer full: evict one stale snapshot, then retry exactly once.
		select {
		case <-sub:
		default:
		}
		select {
		case sub <- payload:
		default:
		}
	}
}
// subscribeRuntimeLive registers a new runtime-snapshot subscriber and
// returns its delivery channel (buffer 1). The shared publisher goroutine
// is started lazily by the first subscriber — the start decision is made
// under the same lock that records the subscription, so only one loop ever
// runs. A watcher goroutine removes the subscriber once ctx is cancelled.
func (s *Server) subscribeRuntimeLive(ctx context.Context) chan []byte {
	ch := make(chan []byte, 1)
	s.liveRuntimeMu.Lock()
	s.liveRuntimeSubs[ch] = struct{}{}
	start := !s.liveRuntimeOn
	if start {
		s.liveRuntimeOn = true
	}
	s.liveRuntimeMu.Unlock()
	if start {
		go s.runtimeLiveLoop()
	}
	// Tie the subscription lifetime to the caller's context.
	go func() {
		<-ctx.Done()
		s.unsubscribeRuntimeLive(ch)
	}()
	return ch
}
// unsubscribeRuntimeLive removes ch from the runtime-snapshot subscriber
// set; the publisher loop shuts itself down once no subscribers remain.
func (s *Server) unsubscribeRuntimeLive(ch chan []byte) {
	s.liveRuntimeMu.Lock()
	defer s.liveRuntimeMu.Unlock()
	delete(s.liveRuntimeSubs, ch)
}
// runtimeLiveLoop is the shared runtime-snapshot publisher: it broadcasts
// one snapshot to all subscribers every 10 seconds. When a publish reports
// failure and no subscribers remain, it clears liveRuntimeOn under the lock
// and exits so a later subscription can restart it.
// NOTE(review): after the last unsubscribe, the loop may still sleep one
// full tick before it notices and exits.
func (s *Server) runtimeLiveLoop() {
	ticker := time.NewTicker(10 * time.Second)
	defer ticker.Stop()
	for {
		if !s.publishRuntimeSnapshot(context.Background()) {
			// Publish failed: stop only if there are truly no subscribers
			// (re-checked under the lock to avoid racing a new subscriber).
			s.liveRuntimeMu.Lock()
			if len(s.liveRuntimeSubs) == 0 {
				s.liveRuntimeOn = false
				s.liveRuntimeMu.Unlock()
				return
			}
			s.liveRuntimeMu.Unlock()
		}
		<-ticker.C
	}
}
// publishRuntimeSnapshot marshals the current runtime snapshot and fans it
// out to every live subscriber. It returns false when there is nothing to
// do (nil server, marshal failure, or no subscribers); the publisher loop
// uses that as its shutdown hint.
func (s *Server) publishRuntimeSnapshot(ctx context.Context) bool {
	if s == nil {
		return false
	}
	payload := map[string]interface{}{
		"ok":       true,
		"type":     "runtime_snapshot",
		"snapshot": s.buildWebUIRuntimeSnapshot(ctx),
	}
	data, err := json.Marshal(payload)
	if err != nil {
		return false
	}
	// Hold the lock across the fan-out so the subscriber set cannot mutate
	// mid-broadcast.
	s.liveRuntimeMu.Lock()
	defer s.liveRuntimeMu.Unlock()
	if len(s.liveRuntimeSubs) == 0 {
		return false
	}
	publishLiveSnapshot(s.liveRuntimeSubs, data)
	return true
}
// buildSubagentLiveKey derives the subscription-group key for a pair of
// task identifiers: both are trimmed and joined with a NUL separator so the
// pair cannot collide with another pair's concatenation.
func buildSubagentLiveKey(taskID, previewTaskID string) string {
	parts := []string{strings.TrimSpace(taskID), strings.TrimSpace(previewTaskID)}
	return strings.Join(parts, "\x00")
}
// subscribeSubagentLive registers a subscriber for live subagent snapshots
// of the (taskID, previewTaskID) pair. The first subscriber of a pair
// creates the fan-out group and starts its publish loop; a watcher
// goroutine unsubscribes the channel when ctx is cancelled.
func (s *Server) subscribeSubagentLive(ctx context.Context, taskID, previewTaskID string) chan []byte {
	ch := make(chan []byte, 1)
	key := buildSubagentLiveKey(taskID, previewTaskID)
	s.liveSubagentMu.Lock()
	group := s.liveSubagents[key]
	if group == nil {
		// First subscriber for this pair: create the group and start its
		// dedicated publish loop.
		group = &liveSubagentGroup{
			taskID:        strings.TrimSpace(taskID),
			previewTaskID: strings.TrimSpace(previewTaskID),
			subs:          map[chan []byte]struct{}{},
			stopCh:        make(chan struct{}),
		}
		s.liveSubagents[key] = group
		go s.subagentLiveLoop(key, group)
	}
	group.subs[ch] = struct{}{}
	s.liveSubagentMu.Unlock()
	// Tie the subscription lifetime to the caller's context.
	go func() {
		<-ctx.Done()
		s.unsubscribeSubagentLive(key, ch)
	}()
	return ch
}
// unsubscribeSubagentLive drops ch from the group identified by key. The
// last unsubscriber removes the group from the registry and signals its
// publish loop to stop by closing stopCh.
func (s *Server) unsubscribeSubagentLive(key string, ch chan []byte) {
	s.liveSubagentMu.Lock()
	defer s.liveSubagentMu.Unlock()
	group, ok := s.liveSubagents[key]
	if !ok || group == nil {
		return
	}
	delete(group.subs, ch)
	if len(group.subs) > 0 {
		return
	}
	delete(s.liveSubagents, key)
	close(group.stopCh)
}
// subagentLiveLoop publishes live snapshots for one subscriber group every
// 2 seconds until the group is torn down (stopCh closed) or a publish
// reports the group gone. group.taskID/previewTaskID are set before this
// goroutine starts and never mutated, so reading them without the lock is
// safe.
// NOTE(review): if a publish fails while subscribers remain (e.g. a marshal
// error), the loop exits and the surviving group receives no more updates.
func (s *Server) subagentLiveLoop(key string, group *liveSubagentGroup) {
	ticker := time.NewTicker(2 * time.Second)
	defer ticker.Stop()
	for {
		if !s.publishSubagentLiveSnapshot(context.Background(), key, group.taskID, group.previewTaskID) {
			return
		}
		select {
		case <-group.stopCh:
			return
		case <-ticker.C:
		}
	}
}
// publishSubagentLiveSnapshot builds one subagents_live payload and
// broadcasts it to the group identified by key. It returns false when the
// server is nil, marshalling fails, or the group no longer has subscribers
// — the publish loop treats that as a signal to exit.
func (s *Server) publishSubagentLiveSnapshot(ctx context.Context, key, taskID, previewTaskID string) bool {
	if s == nil {
		return false
	}
	payload := map[string]interface{}{
		"ok":      true,
		"type":    "subagents_live",
		"payload": s.buildSubagentsLivePayload(ctx, taskID, previewTaskID),
	}
	data, err := json.Marshal(payload)
	if err != nil {
		return false
	}
	// Hold the lock across the fan-out so the group cannot be torn down
	// mid-broadcast.
	s.liveSubagentMu.Lock()
	defer s.liveSubagentMu.Unlock()
	group := s.liveSubagents[key]
	if group == nil || len(group.subs) == 0 {
		return false
	}
	publishLiveSnapshot(group.subs, data)
	return true
}
// SetConfigPath records the active config file path (trimmed).
func (s *Server) SetConfigPath(path string) { s.configPath = strings.TrimSpace(path) }

// SetWorkspacePath records the workspace root path (trimmed).
func (s *Server) SetWorkspacePath(path string) { s.workspacePath = strings.TrimSpace(path) }

// SetLogFilePath records the log file path (trimmed).
func (s *Server) SetLogFilePath(path string) { s.logFilePath = strings.TrimSpace(path) }
@@ -977,28 +1153,23 @@ func (s *Server) handleWebUIRuntime(w http.ResponseWriter, r *http.Request) {
defer conn.Close()
ctx := r.Context()
ticker := time.NewTicker(10 * time.Second)
defer ticker.Stop()
sendSnapshot := func() error {
payload := map[string]interface{}{
"ok": true,
"type": "runtime_snapshot",
"snapshot": s.buildWebUIRuntimeSnapshot(ctx),
}
_ = conn.SetWriteDeadline(time.Now().Add(10 * time.Second))
return conn.WriteJSON(payload)
sub := s.subscribeRuntimeLive(ctx)
initial := map[string]interface{}{
"ok": true,
"type": "runtime_snapshot",
"snapshot": s.buildWebUIRuntimeSnapshot(ctx),
}
if err := sendSnapshot(); err != nil {
_ = conn.SetWriteDeadline(time.Now().Add(10 * time.Second))
if err := conn.WriteJSON(initial); err != nil {
return
}
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
if err := sendSnapshot(); err != nil {
case payload := <-sub:
_ = conn.SetWriteDeadline(time.Now().Add(10 * time.Second))
if err := conn.WriteMessage(websocket.TextMessage, payload); err != nil {
return
}
}
@@ -4245,28 +4416,23 @@ func (s *Server) handleWebUISubagentsRuntimeLive(w http.ResponseWriter, r *http.
ctx := r.Context()
taskID := strings.TrimSpace(r.URL.Query().Get("task_id"))
previewTaskID := strings.TrimSpace(r.URL.Query().Get("preview_task_id"))
ticker := time.NewTicker(2 * time.Second)
defer ticker.Stop()
sendSnapshot := func() error {
payload := map[string]interface{}{
"ok": true,
"type": "subagents_live",
"payload": s.buildSubagentsLivePayload(ctx, taskID, previewTaskID),
}
_ = conn.SetWriteDeadline(time.Now().Add(10 * time.Second))
return conn.WriteJSON(payload)
sub := s.subscribeSubagentLive(ctx, taskID, previewTaskID)
initial := map[string]interface{}{
"ok": true,
"type": "subagents_live",
"payload": s.buildSubagentsLivePayload(ctx, taskID, previewTaskID),
}
if err := sendSnapshot(); err != nil {
_ = conn.SetWriteDeadline(time.Now().Add(10 * time.Second))
if err := conn.WriteJSON(initial); err != nil {
return
}
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
if err := sendSnapshot(); err != nil {
case payload := <-sub:
_ = conn.SetWriteDeadline(time.Now().Add(10 * time.Second))
if err := conn.WriteMessage(websocket.TextMessage, payload); err != nil {
return
}
}

View File

@@ -53,6 +53,7 @@ type SubagentTask struct {
type SubagentManager struct {
tasks map[string]*SubagentTask
cancelFuncs map[string]context.CancelFunc
waiters map[string]map[chan struct{}]struct{}
recoverableTaskIDs []string
archiveAfterMinute int64
mu sync.RWMutex
@@ -92,6 +93,7 @@ func NewSubagentManager(provider providers.LLMProvider, workspace string, bus *b
mgr := &SubagentManager{
tasks: make(map[string]*SubagentTask),
cancelFuncs: make(map[string]context.CancelFunc),
waiters: make(map[string]map[chan struct{}]struct{}),
archiveAfterMinute: 60,
provider: provider,
bus: bus,
@@ -356,6 +358,7 @@ func (sm *SubagentManager) runTask(ctx context.Context, task *SubagentTask) {
CreatedAt: task.Updated,
})
sm.persistTaskLocked(task, "completed", task.Result)
sm.notifyTaskWaitersLocked(task.ID)
} else {
task.Status = "completed"
task.Result = applySubagentResultQuota(result, task.MaxResultChars)
@@ -373,6 +376,7 @@ func (sm *SubagentManager) runTask(ctx context.Context, task *SubagentTask) {
CreatedAt: task.Updated,
})
sm.persistTaskLocked(task, "completed", task.Result)
sm.notifyTaskWaitersLocked(task.ID)
}
sm.mu.Unlock()
@@ -790,6 +794,7 @@ func (sm *SubagentManager) KillTask(taskID string) bool {
t.WaitingReply = false
t.Updated = time.Now().UnixMilli()
sm.persistTaskLocked(t, "killed", "")
sm.notifyTaskWaitersLocked(taskID)
}
return true
}
@@ -1013,6 +1018,96 @@ func (sm *SubagentManager) persistTaskLocked(task *SubagentTask, eventType, mess
})
}
// WaitTask blocks until the task identified by taskID reaches a terminal
// (non-"running") state or ctx is cancelled. It returns a copy of the
// terminal task and true on success. Tasks that already finished — whether
// still in memory or only in the persistent run store — are returned
// immediately without registering a waiter.
func (sm *SubagentManager) WaitTask(ctx context.Context, taskID string) (*SubagentTask, bool, error) {
	if sm == nil {
		return nil, false, fmt.Errorf("subagent manager not available")
	}
	taskID = strings.TrimSpace(taskID)
	if taskID == "" {
		return nil, false, fmt.Errorf("task id is required")
	}
	if ctx == nil {
		ctx = context.Background()
	}
	// Buffered so the notifier's non-blocking send cannot drop the wake-up.
	ch := make(chan struct{}, 1)
	sm.mu.Lock()
	sm.pruneArchivedLocked()
	task, ok := sm.tasks[taskID]
	if !ok && sm.runStore != nil {
		// Unknown in memory: the run may have completed and been persisted.
		if persisted, found := sm.runStore.Get(taskID); found && persisted != nil {
			if strings.TrimSpace(persisted.Status) != "running" {
				sm.mu.Unlock()
				return persisted, true, nil
			}
		}
	}
	if ok && task != nil && strings.TrimSpace(task.Status) != "running" {
		cp := cloneSubagentTask(task)
		sm.mu.Unlock()
		return cp, true, nil
	}
	// Still running: register as a waiter under the same lock that checked
	// the status, so a completion cannot slip between check and register.
	waiters := sm.waiters[taskID]
	if waiters == nil {
		waiters = map[chan struct{}]struct{}{}
		sm.waiters[taskID] = waiters
	}
	waiters[ch] = struct{}{}
	sm.mu.Unlock()
	defer sm.removeTaskWaiter(taskID, ch)
	for {
		select {
		case <-ctx.Done():
			return nil, false, ctx.Err()
		case <-ch:
			// Woken by a state change: re-check memory, then the run store.
			sm.mu.Lock()
			sm.pruneArchivedLocked()
			task, ok := sm.tasks[taskID]
			if ok && task != nil && strings.TrimSpace(task.Status) != "running" {
				cp := cloneSubagentTask(task)
				sm.mu.Unlock()
				return cp, true, nil
			}
			if !ok && sm.runStore != nil {
				if persisted, found := sm.runStore.Get(taskID); found && persisted != nil && strings.TrimSpace(persisted.Status) != "running" {
					sm.mu.Unlock()
					return persisted, true, nil
				}
			}
			// Spurious or premature wake-up: loop and wait again.
			sm.mu.Unlock()
		}
	}
}
// removeTaskWaiter unregisters a single waiter channel for taskID and drops
// the per-task waiter set from sm.waiters once it no longer holds any entry.
func (sm *SubagentManager) removeTaskWaiter(taskID string, ch chan struct{}) {
	sm.mu.Lock()
	defer sm.mu.Unlock()
	set := sm.waiters[taskID]
	if len(set) > 0 {
		delete(set, ch)
	}
	// Either the set was already empty/nil, or removing ch emptied it:
	// in both cases the taskID entry itself is cleared.
	if len(set) == 0 {
		delete(sm.waiters, taskID)
	}
}
// notifyTaskWaitersLocked wakes every waiter registered for taskID and clears
// the waiter set. The caller must hold sm.mu.
func (sm *SubagentManager) notifyTaskWaitersLocked(taskID string) {
	// Ranging over a nil or empty set is a no-op, so no emptiness guard is
	// needed before the loop.
	for ch := range sm.waiters[taskID] {
		// Non-blocking send: each waiter channel is buffered, so a full
		// buffer means that waiter already has a wake-up queued.
		select {
		case ch <- struct{}{}:
		default:
		}
	}
	delete(sm.waiters, taskID)
}
func (sm *SubagentManager) recordMailboxMessageLocked(task *SubagentTask, msg AgentMessage) {
if sm.mailboxStore == nil || task == nil {
return

View File

@@ -8,20 +8,21 @@ import (
)
// RouterDispatchRequest describes a task handed to SubagentRouter.DispatchTask.
// NOTE: the diff rendering had left both the old and the new field sets fused
// into this struct (duplicate field names do not compile); this is the new
// version, which adds NotifyMainPolicy.
type RouterDispatchRequest struct {
	Task             string // task prompt/body forwarded to the subagent
	Label            string
	Role             string
	AgentID          string
	NotifyMainPolicy string // policy for notifying the main agent of results
	ThreadID         string
	CorrelationID    string
	ParentRunID      string
	OriginChannel    string // originating channel; defaults applied downstream
	OriginChatID     string
	MaxRetries       int
	RetryBackoff     int
	TimeoutSec       int
	MaxTaskChars     int
	MaxResultChars   int
}
type RouterReply struct {
@@ -46,20 +47,21 @@ func (r *SubagentRouter) DispatchTask(ctx context.Context, req RouterDispatchReq
return nil, fmt.Errorf("subagent router is not configured")
}
task, err := r.manager.SpawnTask(ctx, SubagentSpawnOptions{
Task: req.Task,
Label: req.Label,
Role: req.Role,
AgentID: req.AgentID,
ThreadID: req.ThreadID,
CorrelationID: req.CorrelationID,
ParentRunID: req.ParentRunID,
OriginChannel: req.OriginChannel,
OriginChatID: req.OriginChatID,
MaxRetries: req.MaxRetries,
RetryBackoff: req.RetryBackoff,
TimeoutSec: req.TimeoutSec,
MaxTaskChars: req.MaxTaskChars,
MaxResultChars: req.MaxResultChars,
Task: req.Task,
Label: req.Label,
Role: req.Role,
AgentID: req.AgentID,
NotifyMainPolicy: req.NotifyMainPolicy,
ThreadID: req.ThreadID,
CorrelationID: req.CorrelationID,
ParentRunID: req.ParentRunID,
OriginChannel: req.OriginChannel,
OriginChatID: req.OriginChatID,
MaxRetries: req.MaxRetries,
RetryBackoff: req.RetryBackoff,
TimeoutSec: req.TimeoutSec,
MaxTaskChars: req.MaxTaskChars,
MaxResultChars: req.MaxResultChars,
})
if err != nil {
return nil, err
@@ -71,33 +73,26 @@ func (r *SubagentRouter) WaitReply(ctx context.Context, taskID string, interval
if r == nil || r.manager == nil {
return nil, fmt.Errorf("subagent router is not configured")
}
if interval <= 0 {
interval = 100 * time.Millisecond
}
_ = interval
taskID = strings.TrimSpace(taskID)
if taskID == "" {
return nil, fmt.Errorf("task id is required")
}
ticker := time.NewTicker(interval)
defer ticker.Stop()
for {
task, ok := r.manager.GetTask(taskID)
if ok && task != nil && task.Status != "running" {
return &RouterReply{
TaskID: task.ID,
ThreadID: task.ThreadID,
CorrelationID: task.CorrelationID,
AgentID: task.AgentID,
Status: task.Status,
Result: strings.TrimSpace(task.Result),
}, nil
}
select {
case <-ctx.Done():
return nil, ctx.Err()
case <-ticker.C:
}
task, ok, err := r.manager.WaitTask(ctx, taskID)
if err != nil {
return nil, err
}
if !ok || task == nil {
return nil, fmt.Errorf("subagent not found")
}
return &RouterReply{
TaskID: task.ID,
ThreadID: task.ThreadID,
CorrelationID: task.CorrelationID,
AgentID: task.AgentID,
Status: task.Status,
Result: strings.TrimSpace(task.Result),
}, nil
}
func (r *SubagentRouter) MergeResults(replies []*RouterReply) string {

View File

@@ -47,3 +47,29 @@ func TestSubagentRouterMergeResults(t *testing.T) {
t.Fatalf("unexpected merged output: %s", out)
}
}
// TestSubagentRouterWaitReplyContextCancel verifies that WaitReply returns an
// error when its context expires before the subagent task completes.
func TestSubagentRouterWaitReplyContextCancel(t *testing.T) {
	manager := NewSubagentManager(nil, t.TempDir(), nil)
	manager.SetRunFunc(func(ctx context.Context, task *SubagentTask) (string, error) {
		// Block until cancelled so the task never finishes on its own.
		<-ctx.Done()
		return "", ctx.Err()
	})
	router := NewSubagentRouter(manager)
	spawned, err := router.DispatchTask(context.Background(), RouterDispatchRequest{
		Task:          "long task",
		AgentID:       "coder",
		OriginChannel: "cli",
		OriginChatID:  "direct",
	})
	if err != nil {
		t.Fatalf("dispatch failed: %v", err)
	}
	waitCtx, cancel := context.WithTimeout(context.Background(), 20*time.Millisecond)
	defer cancel()
	if _, waitErr := router.WaitReply(waitCtx, spawned.ID, 20*time.Millisecond); waitErr == nil {
		t.Fatalf("expected context cancellation error")
	}
}

View File

@@ -25,6 +25,8 @@ const (
maxWorldCycle = 60 * time.Second
)
// GlobalWatchdogTick re-exports the internal watchdog tick interval for use
// outside this file/package scope.
const GlobalWatchdogTick = watchdogTick

// ErrCommandNoProgress signals that a command made no observable progress
// across consecutive watchdog tick rounds.
var ErrCommandNoProgress = errors.New("command no progress across tick rounds")

// ErrTaskWatchdogTimeout signals that a task exceeded its watchdog timeout.
var ErrTaskWatchdogTimeout = errors.New("task watchdog timeout exceeded")