mirror of
https://github.com/YspCoder/clawgo.git
synced 2026-04-12 22:17:29 +08:00
Optimize agent planning and subagent runtime
This commit is contained in:
@@ -67,6 +67,24 @@ type AgentLoop struct {
|
||||
subagentConfigTool *tools.SubagentConfigTool
|
||||
nodeRouter *nodes.Router
|
||||
configPath string
|
||||
subagentDigestMu sync.Mutex
|
||||
subagentDigestDelay time.Duration
|
||||
subagentDigests map[string]*subagentDigestState
|
||||
}
|
||||
|
||||
type subagentDigestItem struct {
|
||||
agentID string
|
||||
reason string
|
||||
status string
|
||||
taskSummary string
|
||||
resultSummary string
|
||||
}
|
||||
|
||||
type subagentDigestState struct {
|
||||
channel string
|
||||
chatID string
|
||||
items map[string]subagentDigestItem
|
||||
dueAt time.Time
|
||||
}
|
||||
|
||||
func (al *AgentLoop) SetConfigPath(path string) {
|
||||
@@ -320,7 +338,10 @@ func NewAgentLoop(cfg *config.Config, msgBus *bus.MessageBus, provider providers
|
||||
subagentRouter: subagentRouter,
|
||||
subagentConfigTool: subagentConfigTool,
|
||||
nodeRouter: nodesRouter,
|
||||
subagentDigestDelay: 5 * time.Second,
|
||||
subagentDigests: map[string]*subagentDigestState{},
|
||||
}
|
||||
go loop.runSubagentDigestTicker()
|
||||
// Initialize provider fallback chain (primary + proxy_fallbacks).
|
||||
loop.providerPool = map[string]providers.LLMProvider{}
|
||||
loop.providerNames = []string{}
|
||||
@@ -371,7 +392,31 @@ func NewAgentLoop(cfg *config.Config, msgBus *bus.MessageBus, provider providers
|
||||
sessionKey = fmt.Sprintf("subagent:%s", strings.TrimSpace(task.ID))
|
||||
}
|
||||
taskInput := loop.buildSubagentTaskInput(task)
|
||||
return loop.ProcessDirectWithOptions(ctx, taskInput, sessionKey, task.OriginChannel, task.OriginChatID, task.MemoryNS, task.ToolAllowlist)
|
||||
ns := normalizeMemoryNamespace(task.MemoryNS)
|
||||
ctx = withMemoryNamespaceContext(ctx, ns)
|
||||
ctx = withToolAllowlistContext(ctx, task.ToolAllowlist)
|
||||
channel := strings.TrimSpace(task.OriginChannel)
|
||||
if channel == "" {
|
||||
channel = "cli"
|
||||
}
|
||||
chatID := strings.TrimSpace(task.OriginChatID)
|
||||
if chatID == "" {
|
||||
chatID = "direct"
|
||||
}
|
||||
msg := bus.InboundMessage{
|
||||
Channel: channel,
|
||||
SenderID: "subagent",
|
||||
ChatID: chatID,
|
||||
Content: taskInput,
|
||||
SessionKey: sessionKey,
|
||||
Metadata: map[string]string{
|
||||
"memory_namespace": ns,
|
||||
"memory_ns": ns,
|
||||
"disable_planning": "true",
|
||||
"trigger": "subagent",
|
||||
},
|
||||
}
|
||||
return loop.processMessage(ctx, msg)
|
||||
})
|
||||
|
||||
return loop
|
||||
@@ -1349,6 +1394,10 @@ func (al *AgentLoop) processSystemMessage(ctx context.Context, msg bus.InboundMe
|
||||
"chat_id": msg.ChatID,
|
||||
})
|
||||
|
||||
if al.handleSubagentSystemMessage(msg) {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
originChannel, originChatID := resolveSystemOrigin(msg.ChatID)
|
||||
|
||||
// Use the origin session for context
|
||||
@@ -1491,6 +1540,36 @@ func (al *AgentLoop) processSystemMessage(ctx context.Context, msg bus.InboundMe
|
||||
return finalContent, nil
|
||||
}
|
||||
|
||||
func (al *AgentLoop) handleSubagentSystemMessage(msg bus.InboundMessage) bool {
|
||||
if !isSubagentSystemMessage(msg) {
|
||||
return false
|
||||
}
|
||||
reason := ""
|
||||
agentID := ""
|
||||
status := ""
|
||||
taskSummary := ""
|
||||
resultSummary := ""
|
||||
if msg.Metadata != nil {
|
||||
reason = strings.ToLower(strings.TrimSpace(msg.Metadata["notify_reason"]))
|
||||
agentID = strings.TrimSpace(msg.Metadata["agent_id"])
|
||||
status = strings.TrimSpace(msg.Metadata["status"])
|
||||
}
|
||||
if agentID == "" {
|
||||
agentID = strings.TrimSpace(strings.TrimPrefix(msg.SenderID, "subagent:"))
|
||||
}
|
||||
if taskSummary == "" || resultSummary == "" {
|
||||
taskSummary, resultSummary = parseSubagentSystemContent(msg.Content)
|
||||
}
|
||||
al.enqueueSubagentDigest(msg, subagentDigestItem{
|
||||
agentID: agentID,
|
||||
reason: reason,
|
||||
status: status,
|
||||
taskSummary: taskSummary,
|
||||
resultSummary: resultSummary,
|
||||
})
|
||||
return true
|
||||
}
|
||||
|
||||
// truncate returns a truncated version of s with at most maxLen characters.
|
||||
// If the string is truncated, "..." is appended to indicate truncation.
|
||||
// If the string fits within maxLen, it is returned unchanged.
|
||||
@@ -2219,6 +2298,219 @@ func fallbackSubagentNotification(msg bus.InboundMessage) (string, bool) {
|
||||
return content, true
|
||||
}
|
||||
|
||||
func parseSubagentSystemContent(content string) (string, string) {
|
||||
content = strings.TrimSpace(content)
|
||||
if content == "" {
|
||||
return "", ""
|
||||
}
|
||||
lines := strings.Split(content, "\n")
|
||||
taskSummary := ""
|
||||
resultSummary := ""
|
||||
for _, line := range lines {
|
||||
t := strings.TrimSpace(line)
|
||||
if t == "" {
|
||||
continue
|
||||
}
|
||||
lower := strings.ToLower(t)
|
||||
switch {
|
||||
case strings.HasPrefix(lower, "task:"):
|
||||
taskSummary = strings.TrimSpace(t[len("task:"):])
|
||||
case strings.HasPrefix(lower, "summary:"):
|
||||
resultSummary = strings.TrimSpace(t[len("summary:"):])
|
||||
}
|
||||
}
|
||||
if taskSummary == "" && len(lines) > 0 {
|
||||
taskSummary = summarizeSystemNotificationText(content, 120)
|
||||
}
|
||||
return taskSummary, resultSummary
|
||||
}
|
||||
|
||||
func summarizeSystemNotificationText(s string, max int) string {
|
||||
s = strings.TrimSpace(strings.ReplaceAll(s, "\r\n", "\n"))
|
||||
s = strings.ReplaceAll(s, "\n", " ")
|
||||
s = strings.Join(strings.Fields(s), " ")
|
||||
if s == "" {
|
||||
return ""
|
||||
}
|
||||
if max > 0 && len(s) > max {
|
||||
return strings.TrimSpace(s[:max-3]) + "..."
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
func (al *AgentLoop) enqueueSubagentDigest(msg bus.InboundMessage, item subagentDigestItem) {
|
||||
if al == nil || al.bus == nil {
|
||||
return
|
||||
}
|
||||
originChannel, originChatID := resolveSystemOrigin(msg.ChatID)
|
||||
key := originChannel + "\x00" + originChatID
|
||||
delay := al.subagentDigestDelay
|
||||
if delay <= 0 {
|
||||
delay = 5 * time.Second
|
||||
}
|
||||
|
||||
al.subagentDigestMu.Lock()
|
||||
state, ok := al.subagentDigests[key]
|
||||
if !ok || state == nil {
|
||||
state = &subagentDigestState{
|
||||
channel: originChannel,
|
||||
chatID: originChatID,
|
||||
items: map[string]subagentDigestItem{},
|
||||
}
|
||||
al.subagentDigests[key] = state
|
||||
}
|
||||
itemKey := subagentDigestItemKey(item)
|
||||
state.items[itemKey] = item
|
||||
state.dueAt = time.Now().Add(delay)
|
||||
al.subagentDigestMu.Unlock()
|
||||
}
|
||||
|
||||
func subagentDigestItemKey(item subagentDigestItem) string {
|
||||
agentID := strings.ToLower(strings.TrimSpace(item.agentID))
|
||||
reason := strings.ToLower(strings.TrimSpace(item.reason))
|
||||
task := strings.ToLower(strings.TrimSpace(item.taskSummary))
|
||||
if agentID == "" {
|
||||
agentID = "subagent"
|
||||
}
|
||||
return agentID + "\x00" + reason + "\x00" + task
|
||||
}
|
||||
|
||||
func (al *AgentLoop) flushSubagentDigest(key string) {
|
||||
if al == nil || al.bus == nil {
|
||||
return
|
||||
}
|
||||
al.subagentDigestMu.Lock()
|
||||
state := al.subagentDigests[key]
|
||||
delete(al.subagentDigests, key)
|
||||
al.subagentDigestMu.Unlock()
|
||||
if state == nil || len(state.items) == 0 {
|
||||
return
|
||||
}
|
||||
content := formatSubagentDigestSummary(state.items)
|
||||
if strings.TrimSpace(content) == "" {
|
||||
return
|
||||
}
|
||||
al.bus.PublishOutbound(bus.OutboundMessage{
|
||||
Channel: state.channel,
|
||||
ChatID: state.chatID,
|
||||
Content: content,
|
||||
})
|
||||
}
|
||||
|
||||
func (al *AgentLoop) runSubagentDigestTicker() {
|
||||
if al == nil {
|
||||
return
|
||||
}
|
||||
tick := tools.GlobalWatchdogTick
|
||||
if tick <= 0 {
|
||||
tick = time.Second
|
||||
}
|
||||
ticker := time.NewTicker(tick)
|
||||
defer ticker.Stop()
|
||||
for now := range ticker.C {
|
||||
al.flushDueSubagentDigests(now)
|
||||
}
|
||||
}
|
||||
|
||||
func (al *AgentLoop) flushDueSubagentDigests(now time.Time) {
|
||||
if al == nil || al.bus == nil {
|
||||
return
|
||||
}
|
||||
dueKeys := make([]string, 0, 4)
|
||||
al.subagentDigestMu.Lock()
|
||||
for key, state := range al.subagentDigests {
|
||||
if state == nil || state.dueAt.IsZero() || now.Before(state.dueAt) {
|
||||
continue
|
||||
}
|
||||
dueKeys = append(dueKeys, key)
|
||||
}
|
||||
al.subagentDigestMu.Unlock()
|
||||
for _, key := range dueKeys {
|
||||
al.flushSubagentDigest(key)
|
||||
}
|
||||
}
|
||||
|
||||
func formatSubagentDigestSummary(items map[string]subagentDigestItem) string {
|
||||
if len(items) == 0 {
|
||||
return ""
|
||||
}
|
||||
list := make([]subagentDigestItem, 0, len(items))
|
||||
completed := 0
|
||||
blocked := 0
|
||||
milestone := 0
|
||||
failed := 0
|
||||
for _, item := range items {
|
||||
list = append(list, item)
|
||||
switch {
|
||||
case strings.EqualFold(strings.TrimSpace(item.reason), "blocked"):
|
||||
blocked++
|
||||
case strings.EqualFold(strings.TrimSpace(item.reason), "milestone"):
|
||||
milestone++
|
||||
case strings.EqualFold(strings.TrimSpace(item.status), "failed"):
|
||||
failed++
|
||||
default:
|
||||
completed++
|
||||
}
|
||||
}
|
||||
sort.Slice(list, func(i, j int) bool {
|
||||
left := strings.TrimSpace(list[i].agentID)
|
||||
right := strings.TrimSpace(list[j].agentID)
|
||||
if left == right {
|
||||
return strings.TrimSpace(list[i].taskSummary) < strings.TrimSpace(list[j].taskSummary)
|
||||
}
|
||||
return left < right
|
||||
})
|
||||
var sb strings.Builder
|
||||
sb.WriteString("阶段总结")
|
||||
stats := make([]string, 0, 4)
|
||||
if completed > 0 {
|
||||
stats = append(stats, fmt.Sprintf("完成 %d", completed))
|
||||
}
|
||||
if blocked > 0 {
|
||||
stats = append(stats, fmt.Sprintf("受阻 %d", blocked))
|
||||
}
|
||||
if failed > 0 {
|
||||
stats = append(stats, fmt.Sprintf("失败 %d", failed))
|
||||
}
|
||||
if milestone > 0 {
|
||||
stats = append(stats, fmt.Sprintf("进展 %d", milestone))
|
||||
}
|
||||
if len(stats) > 0 {
|
||||
sb.WriteString("(" + strings.Join(stats, ",") + ")")
|
||||
}
|
||||
sb.WriteString("\n")
|
||||
for _, item := range list {
|
||||
agentLabel := strings.TrimSpace(item.agentID)
|
||||
if agentLabel == "" {
|
||||
agentLabel = "subagent"
|
||||
}
|
||||
statusText := "已完成"
|
||||
switch {
|
||||
case strings.EqualFold(strings.TrimSpace(item.reason), "blocked"):
|
||||
statusText = "受阻"
|
||||
case strings.EqualFold(strings.TrimSpace(item.reason), "milestone"):
|
||||
statusText = "有进展"
|
||||
case strings.EqualFold(strings.TrimSpace(item.status), "failed"):
|
||||
statusText = "失败"
|
||||
}
|
||||
sb.WriteString("- " + agentLabel + ":" + statusText)
|
||||
if task := strings.TrimSpace(item.taskSummary); task != "" {
|
||||
sb.WriteString(",任务:" + task)
|
||||
}
|
||||
if summary := strings.TrimSpace(item.resultSummary); summary != "" {
|
||||
label := "摘要"
|
||||
if statusText == "受阻" {
|
||||
label = "原因"
|
||||
} else if statusText == "有进展" {
|
||||
label = "进度"
|
||||
}
|
||||
sb.WriteString("," + label + ":" + summary)
|
||||
}
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
return strings.TrimSpace(sb.String())
|
||||
}
|
||||
|
||||
func shouldFlushTelegramStreamSnapshot(s string) bool {
|
||||
s = strings.TrimRight(s, " \t")
|
||||
if s == "" {
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"clawgo/pkg/bus"
|
||||
)
|
||||
@@ -66,3 +68,92 @@ func TestPrepareOutboundSubagentNoReplyFallbackWithMissingOrigin(t *testing.T) {
|
||||
t.Fatalf("unexpected fallback content: %q", outbound.Content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessSystemMessageSubagentBlockedQueuedIntoDigest(t *testing.T) {
|
||||
msgBus := bus.NewMessageBus()
|
||||
al := &AgentLoop{
|
||||
bus: msgBus,
|
||||
subagentDigestDelay: 10 * time.Millisecond,
|
||||
subagentDigests: map[string]*subagentDigestState{},
|
||||
}
|
||||
out, err := al.processSystemMessage(context.Background(), bus.InboundMessage{
|
||||
Channel: "system",
|
||||
SenderID: "subagent:subagent-3",
|
||||
ChatID: "telegram:9527",
|
||||
Content: "Subagent update\nagent: coder\nrun: subagent-3\nstatus: blocked\nreason: blocked\ntask: 修复登录\nsummary: rate limit",
|
||||
Metadata: map[string]string{
|
||||
"trigger": "subagent",
|
||||
"agent_id": "coder",
|
||||
"status": "failed",
|
||||
"notify_reason": "blocked",
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if out != "" {
|
||||
t.Fatalf("expected queued digest with no immediate output, got %q", out)
|
||||
}
|
||||
al.flushDueSubagentDigests(time.Now().Add(20 * time.Millisecond))
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
outbound, ok := msgBus.SubscribeOutbound(ctx)
|
||||
if !ok {
|
||||
t.Fatalf("expected outbound digest")
|
||||
}
|
||||
if !strings.Contains(outbound.Content, "阶段总结") || !strings.Contains(outbound.Content, "coder") || !strings.Contains(outbound.Content, "受阻") {
|
||||
t.Fatalf("unexpected digest content: %q", outbound.Content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessSystemMessageSubagentDigestMergesMultipleUpdates(t *testing.T) {
|
||||
msgBus := bus.NewMessageBus()
|
||||
al := &AgentLoop{
|
||||
bus: msgBus,
|
||||
subagentDigestDelay: 10 * time.Millisecond,
|
||||
subagentDigests: map[string]*subagentDigestState{},
|
||||
}
|
||||
first := bus.InboundMessage{
|
||||
Channel: "system",
|
||||
SenderID: "subagent:subagent-7",
|
||||
ChatID: "telegram:9527",
|
||||
Content: "Subagent update\nagent: tester\nrun: subagent-7\nstatus: completed\nreason: final\ntask: 回归测试\nsummary: 所有测试通过",
|
||||
Metadata: map[string]string{
|
||||
"trigger": "subagent",
|
||||
"agent_id": "tester",
|
||||
"status": "completed",
|
||||
"notify_reason": "final",
|
||||
},
|
||||
}
|
||||
second := bus.InboundMessage{
|
||||
Channel: "system",
|
||||
SenderID: "subagent:subagent-8",
|
||||
ChatID: "telegram:9527",
|
||||
Content: "Subagent update\nagent: coder\nrun: subagent-8\nstatus: completed\nreason: final\ntask: 修复登录\nsummary: 接口已联调",
|
||||
Metadata: map[string]string{
|
||||
"trigger": "subagent",
|
||||
"agent_id": "coder",
|
||||
"status": "completed",
|
||||
"notify_reason": "final",
|
||||
},
|
||||
}
|
||||
if out, err := al.processSystemMessage(context.Background(), first); err != nil || out != "" {
|
||||
t.Fatalf("unexpected first result out=%q err=%v", out, err)
|
||||
}
|
||||
if out, err := al.processSystemMessage(context.Background(), second); err != nil || out != "" {
|
||||
t.Fatalf("unexpected second result out=%q err=%v", out, err)
|
||||
}
|
||||
al.flushDueSubagentDigests(time.Now().Add(20 * time.Millisecond))
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
outbound, ok := msgBus.SubscribeOutbound(ctx)
|
||||
if !ok {
|
||||
t.Fatalf("expected merged outbound digest")
|
||||
}
|
||||
if strings.Count(outbound.Content, "\n- ") != 2 {
|
||||
t.Fatalf("expected two digest lines, got: %q", outbound.Content)
|
||||
}
|
||||
if !strings.Contains(outbound.Content, "完成 2") {
|
||||
t.Fatalf("expected aggregate completion count, got: %q", outbound.Content)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -38,10 +38,11 @@ func (al *AgentLoop) maybeAutoRoute(ctx context.Context, msg bus.InboundMessage)
|
||||
waitCtx, cancel := context.WithTimeout(ctx, time.Duration(waitTimeout)*time.Second)
|
||||
defer cancel()
|
||||
task, err := al.subagentRouter.DispatchTask(waitCtx, tools.RouterDispatchRequest{
|
||||
Task: taskText,
|
||||
AgentID: agentID,
|
||||
OriginChannel: msg.Channel,
|
||||
OriginChatID: msg.ChatID,
|
||||
Task: taskText,
|
||||
AgentID: agentID,
|
||||
NotifyMainPolicy: "internal_only",
|
||||
OriginChannel: msg.Channel,
|
||||
OriginChatID: msg.ChatID,
|
||||
})
|
||||
if err != nil {
|
||||
return "", true, err
|
||||
|
||||
@@ -76,20 +76,21 @@ func (al *AgentLoop) HandleSubagentRuntime(ctx context.Context, action string, a
|
||||
return nil, fmt.Errorf("task is required")
|
||||
}
|
||||
task, err := router.DispatchTask(ctx, tools.RouterDispatchRequest{
|
||||
Task: taskInput,
|
||||
Label: runtimeStringArg(args, "label"),
|
||||
Role: runtimeStringArg(args, "role"),
|
||||
AgentID: runtimeStringArg(args, "agent_id"),
|
||||
ThreadID: runtimeStringArg(args, "thread_id"),
|
||||
CorrelationID: runtimeStringArg(args, "correlation_id"),
|
||||
ParentRunID: runtimeStringArg(args, "parent_run_id"),
|
||||
OriginChannel: fallbackString(runtimeStringArg(args, "channel"), "webui"),
|
||||
OriginChatID: fallbackString(runtimeStringArg(args, "chat_id"), "webui"),
|
||||
MaxRetries: runtimeIntArg(args, "max_retries", 0),
|
||||
RetryBackoff: runtimeIntArg(args, "retry_backoff_ms", 0),
|
||||
TimeoutSec: runtimeIntArg(args, "timeout_sec", 0),
|
||||
MaxTaskChars: runtimeIntArg(args, "max_task_chars", 0),
|
||||
MaxResultChars: runtimeIntArg(args, "max_result_chars", 0),
|
||||
Task: taskInput,
|
||||
Label: runtimeStringArg(args, "label"),
|
||||
Role: runtimeStringArg(args, "role"),
|
||||
AgentID: runtimeStringArg(args, "agent_id"),
|
||||
NotifyMainPolicy: "internal_only",
|
||||
ThreadID: runtimeStringArg(args, "thread_id"),
|
||||
CorrelationID: runtimeStringArg(args, "correlation_id"),
|
||||
ParentRunID: runtimeStringArg(args, "parent_run_id"),
|
||||
OriginChannel: fallbackString(runtimeStringArg(args, "channel"), "webui"),
|
||||
OriginChatID: fallbackString(runtimeStringArg(args, "chat_id"), "webui"),
|
||||
MaxRetries: runtimeIntArg(args, "max_retries", 0),
|
||||
RetryBackoff: runtimeIntArg(args, "retry_backoff_ms", 0),
|
||||
TimeoutSec: runtimeIntArg(args, "timeout_sec", 0),
|
||||
MaxTaskChars: runtimeIntArg(args, "max_task_chars", 0),
|
||||
MaxResultChars: runtimeIntArg(args, "max_result_chars", 0),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
||||
@@ -12,9 +12,11 @@ import (
|
||||
"regexp"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"clawgo/pkg/bus"
|
||||
"clawgo/pkg/ekg"
|
||||
"clawgo/pkg/providers"
|
||||
"clawgo/pkg/scheduling"
|
||||
)
|
||||
|
||||
@@ -35,18 +37,23 @@ type plannedTaskResult struct {
|
||||
var reLeadingNumber = regexp.MustCompile(`^\d+[\.)、]\s*`)
|
||||
|
||||
func (al *AgentLoop) processPlannedMessage(ctx context.Context, msg bus.InboundMessage) (string, error) {
|
||||
tasks := al.planSessionTasks(msg)
|
||||
tasks := al.planSessionTasks(ctx, msg)
|
||||
if len(tasks) <= 1 {
|
||||
return al.processMessage(ctx, msg)
|
||||
}
|
||||
return al.runPlannedTasks(ctx, msg, tasks)
|
||||
}
|
||||
|
||||
func (al *AgentLoop) planSessionTasks(msg bus.InboundMessage) []plannedTask {
|
||||
func (al *AgentLoop) planSessionTasks(ctx context.Context, msg bus.InboundMessage) []plannedTask {
|
||||
base := strings.TrimSpace(msg.Content)
|
||||
if base == "" {
|
||||
return nil
|
||||
}
|
||||
if msg.Metadata != nil {
|
||||
if planningDisabled(msg.Metadata["disable_planning"]) {
|
||||
return []plannedTask{{Index: 1, Content: base, ResourceKeys: scheduling.DeriveResourceKeys(base)}}
|
||||
}
|
||||
}
|
||||
if msg.Channel == "system" || msg.Channel == "internal" {
|
||||
return []plannedTask{{Index: 1, Content: base, ResourceKeys: scheduling.DeriveResourceKeys(base)}}
|
||||
}
|
||||
@@ -63,6 +70,9 @@ func (al *AgentLoop) planSessionTasks(msg bus.InboundMessage) []plannedTask {
|
||||
if len(segments) <= 1 {
|
||||
return []plannedTask{{Index: 1, Content: base, ResourceKeys: scheduling.DeriveResourceKeys(base)}}
|
||||
}
|
||||
if refined, ok := al.inferPlannedSegments(ctx, base, segments); ok {
|
||||
segments = refined
|
||||
}
|
||||
|
||||
out := make([]plannedTask, 0, len(segments))
|
||||
for i, seg := range segments {
|
||||
@@ -86,6 +96,125 @@ func (al *AgentLoop) planSessionTasks(msg bus.InboundMessage) []plannedTask {
|
||||
return out
|
||||
}
|
||||
|
||||
type plannerDecision struct {
|
||||
ShouldSplit bool `json:"should_split"`
|
||||
Tasks []string `json:"tasks"`
|
||||
}
|
||||
|
||||
func (al *AgentLoop) inferPlannedSegments(ctx context.Context, content string, candidates []string) ([]string, bool) {
|
||||
if al == nil || al.provider == nil {
|
||||
return nil, false
|
||||
}
|
||||
content = strings.TrimSpace(content)
|
||||
if content == "" || len(candidates) <= 1 {
|
||||
return nil, false
|
||||
}
|
||||
plannerCtx := ctx
|
||||
if plannerCtx == nil {
|
||||
plannerCtx = context.Background()
|
||||
}
|
||||
var cancel context.CancelFunc
|
||||
plannerCtx, cancel = context.WithTimeout(plannerCtx, 8*time.Second)
|
||||
defer cancel()
|
||||
|
||||
previewCandidates := candidates
|
||||
if len(previewCandidates) > 12 {
|
||||
previewCandidates = previewCandidates[:12]
|
||||
}
|
||||
var candidateList strings.Builder
|
||||
for i, item := range previewCandidates {
|
||||
candidateList.WriteString(fmt.Sprintf("%d. %s\n", i+1, strings.TrimSpace(item)))
|
||||
}
|
||||
|
||||
resp, err := al.provider.Chat(plannerCtx, []providers.Message{
|
||||
{
|
||||
Role: "system",
|
||||
Content: "Decide whether the request should stay as one task or be split into a small number of high-level task groups. " +
|
||||
"Default to no split. Only split when the request contains multiple independent deliverables, roles, or workstreams. " +
|
||||
"Never split into fine-grained execution steps. Merge related steps. Return strict JSON only: " +
|
||||
`{"should_split":true|false,"tasks":["..."]}` +
|
||||
" If should_split is false, tasks must be empty. If true, tasks must contain 2 to 8 concise high-level task groups.",
|
||||
},
|
||||
{
|
||||
Role: "user",
|
||||
Content: fmt.Sprintf("Original request:\n%s\n\nRule-based candidate segments (%d shown of %d):\n%s",
|
||||
content,
|
||||
len(previewCandidates),
|
||||
len(candidates),
|
||||
strings.TrimSpace(candidateList.String()),
|
||||
),
|
||||
},
|
||||
}, nil, al.provider.GetDefaultModel(), map[string]interface{}{
|
||||
"max_tokens": 256,
|
||||
})
|
||||
if err != nil || resp == nil {
|
||||
return nil, false
|
||||
}
|
||||
decision, ok := parsePlannerDecision(resp.Content)
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
if !decision.ShouldSplit {
|
||||
return []string{content}, true
|
||||
}
|
||||
tasks := sanitizePlannerTasks(decision.Tasks)
|
||||
if len(tasks) < 2 || len(tasks) > 8 {
|
||||
return []string{content}, true
|
||||
}
|
||||
return tasks, true
|
||||
}
|
||||
|
||||
func parsePlannerDecision(raw string) (plannerDecision, bool) {
|
||||
raw = strings.TrimSpace(raw)
|
||||
if raw == "" {
|
||||
return plannerDecision{}, false
|
||||
}
|
||||
if fenced := extractJSONObject(raw); fenced != "" {
|
||||
raw = fenced
|
||||
}
|
||||
var out plannerDecision
|
||||
if err := json.Unmarshal([]byte(raw), &out); err != nil {
|
||||
return plannerDecision{}, false
|
||||
}
|
||||
return out, true
|
||||
}
|
||||
|
||||
func extractJSONObject(raw string) string {
|
||||
start := strings.Index(raw, "{")
|
||||
end := strings.LastIndex(raw, "}")
|
||||
if start < 0 || end <= start {
|
||||
return ""
|
||||
}
|
||||
return strings.TrimSpace(raw[start : end+1])
|
||||
}
|
||||
|
||||
func sanitizePlannerTasks(items []string) []string {
|
||||
out := make([]string, 0, len(items))
|
||||
seen := map[string]struct{}{}
|
||||
for _, item := range items {
|
||||
t := strings.TrimSpace(reLeadingNumber.ReplaceAllString(strings.TrimSpace(item), ""))
|
||||
if t == "" {
|
||||
continue
|
||||
}
|
||||
key := strings.ToLower(t)
|
||||
if _, ok := seen[key]; ok {
|
||||
continue
|
||||
}
|
||||
seen[key] = struct{}{}
|
||||
out = append(out, t)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func planningDisabled(v string) bool {
|
||||
switch strings.ToLower(strings.TrimSpace(v)) {
|
||||
case "1", "true", "yes", "on":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func splitPlannedSegments(content string) []string {
|
||||
lines := strings.Split(content, "\n")
|
||||
bullet := make([]string, 0, len(lines))
|
||||
|
||||
@@ -1,6 +1,12 @@
|
||||
package agent
|
||||
|
||||
import "testing"
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"clawgo/pkg/bus"
|
||||
"clawgo/pkg/providers"
|
||||
)
|
||||
|
||||
func TestSplitPlannedSegmentsDoesNotSplitPlainNewlines(t *testing.T) {
|
||||
t.Parallel()
|
||||
@@ -31,3 +37,75 @@ func TestSplitPlannedSegmentsStillSplitsSemicolons(t *testing.T) {
|
||||
t.Fatalf("expected 2 segments, got %d: %#v", len(got), got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPlanSessionTasksDisablePlanningMetadataSkipsBulletSplit(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
al := &AgentLoop{}
|
||||
msg := bus.InboundMessage{
|
||||
Content: "Role Profile Policy:\n1. 先做代码审查\n2. 再执行测试\n\nTask:\n修复登录流程",
|
||||
Metadata: map[string]string{
|
||||
"disable_planning": "true",
|
||||
},
|
||||
}
|
||||
|
||||
got := al.planSessionTasks(context.Background(), msg)
|
||||
if len(got) != 1 {
|
||||
t.Fatalf("expected 1 segment, got %d: %#v", len(got), got)
|
||||
}
|
||||
if got[0].Content != msg.Content {
|
||||
t.Fatalf("expected original content to be preserved, got: %#v", got[0].Content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPlanSessionTasksUsesAIPlannerDecisionToAvoidOversplitting(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
al := &AgentLoop{
|
||||
provider: plannerStubProvider{content: `{"should_split":false,"tasks":[]}`},
|
||||
}
|
||||
msg := bus.InboundMessage{
|
||||
Content: "1. 调研\n2. 写方案\n3. 设计数据库\n4. 写接口\n5. 写前端\n6. 写测试",
|
||||
}
|
||||
|
||||
got := al.planSessionTasks(context.Background(), msg)
|
||||
if len(got) != 1 {
|
||||
t.Fatalf("expected AI planner to keep one task, got %d: %#v", len(got), got)
|
||||
}
|
||||
if got[0].Content != msg.Content {
|
||||
t.Fatalf("expected original content, got %#v", got[0].Content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPlanSessionTasksUsesAIPlannerDecisionToSplitIntoHighLevelGroups(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
al := &AgentLoop{
|
||||
provider: plannerStubProvider{content: `{"should_split":true,"tasks":["产品方案","研发实现","测试验收"]}`},
|
||||
}
|
||||
msg := bus.InboundMessage{
|
||||
Content: "1. 调研\n2. 写方案\n3. 设计数据库\n4. 写接口\n5. 写前端\n6. 写测试",
|
||||
}
|
||||
|
||||
got := al.planSessionTasks(context.Background(), msg)
|
||||
if len(got) != 3 {
|
||||
t.Fatalf("expected 3 planned tasks, got %d: %#v", len(got), got)
|
||||
}
|
||||
if got[0].Content != "产品方案" || got[1].Content != "研发实现" || got[2].Content != "测试验收" {
|
||||
t.Fatalf("unexpected planned task contents: %#v", got)
|
||||
}
|
||||
}
|
||||
|
||||
type plannerStubProvider struct {
|
||||
content string
|
||||
err error
|
||||
}
|
||||
|
||||
func (p plannerStubProvider) Chat(ctx context.Context, messages []providers.Message, tools []providers.ToolDefinition, model string, options map[string]interface{}) (*providers.LLMResponse, error) {
|
||||
if p.err != nil {
|
||||
return nil, p.err
|
||||
}
|
||||
return &providers.LLMResponse{Content: p.content}, nil
|
||||
}
|
||||
|
||||
func (p plannerStubProvider) GetDefaultModel() string { return "stub" }
|
||||
|
||||
@@ -64,6 +64,11 @@ type Server struct {
|
||||
ekgCacheStamp time.Time
|
||||
ekgCacheSize int64
|
||||
ekgCacheRows []map[string]interface{}
|
||||
liveRuntimeMu sync.Mutex
|
||||
liveRuntimeSubs map[chan []byte]struct{}
|
||||
liveRuntimeOn bool
|
||||
liveSubagentMu sync.Mutex
|
||||
liveSubagents map[string]*liveSubagentGroup
|
||||
}
|
||||
|
||||
var nodesWebsocketUpgrader = websocket.Upgrader{
|
||||
@@ -79,12 +84,14 @@ func NewServer(host string, port int, token string, mgr *nodes.Manager) *Server
|
||||
port = 7788
|
||||
}
|
||||
return &Server{
|
||||
addr: fmt.Sprintf("%s:%d", addr, port),
|
||||
token: strings.TrimSpace(token),
|
||||
mgr: mgr,
|
||||
nodeConnIDs: map[string]string{},
|
||||
nodeSockets: map[string]*nodeSocketConn{},
|
||||
artifactStats: map[string]interface{}{},
|
||||
addr: fmt.Sprintf("%s:%d", addr, port),
|
||||
token: strings.TrimSpace(token),
|
||||
mgr: mgr,
|
||||
nodeConnIDs: map[string]string{},
|
||||
nodeSockets: map[string]*nodeSocketConn{},
|
||||
artifactStats: map[string]interface{}{},
|
||||
liveRuntimeSubs: map[chan []byte]struct{}{},
|
||||
liveSubagents: map[string]*liveSubagentGroup{},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -94,6 +101,13 @@ type nodeSocketConn struct {
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
type liveSubagentGroup struct {
|
||||
taskID string
|
||||
previewTaskID string
|
||||
subs map[chan []byte]struct{}
|
||||
stopCh chan struct{}
|
||||
}
|
||||
|
||||
func (c *nodeSocketConn) Send(msg nodes.WireMessage) error {
|
||||
if c == nil || c.conn == nil {
|
||||
return fmt.Errorf("node websocket unavailable")
|
||||
@@ -104,6 +118,168 @@ func (c *nodeSocketConn) Send(msg nodes.WireMessage) error {
|
||||
return c.conn.WriteJSON(msg)
|
||||
}
|
||||
|
||||
func publishLiveSnapshot(subs map[chan []byte]struct{}, payload []byte) {
|
||||
for ch := range subs {
|
||||
select {
|
||||
case ch <- payload:
|
||||
default:
|
||||
select {
|
||||
case <-ch:
|
||||
default:
|
||||
}
|
||||
select {
|
||||
case ch <- payload:
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) subscribeRuntimeLive(ctx context.Context) chan []byte {
|
||||
ch := make(chan []byte, 1)
|
||||
s.liveRuntimeMu.Lock()
|
||||
s.liveRuntimeSubs[ch] = struct{}{}
|
||||
start := !s.liveRuntimeOn
|
||||
if start {
|
||||
s.liveRuntimeOn = true
|
||||
}
|
||||
s.liveRuntimeMu.Unlock()
|
||||
if start {
|
||||
go s.runtimeLiveLoop()
|
||||
}
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
s.unsubscribeRuntimeLive(ch)
|
||||
}()
|
||||
return ch
|
||||
}
|
||||
|
||||
func (s *Server) unsubscribeRuntimeLive(ch chan []byte) {
|
||||
s.liveRuntimeMu.Lock()
|
||||
delete(s.liveRuntimeSubs, ch)
|
||||
s.liveRuntimeMu.Unlock()
|
||||
}
|
||||
|
||||
func (s *Server) runtimeLiveLoop() {
|
||||
ticker := time.NewTicker(10 * time.Second)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
if !s.publishRuntimeSnapshot(context.Background()) {
|
||||
s.liveRuntimeMu.Lock()
|
||||
if len(s.liveRuntimeSubs) == 0 {
|
||||
s.liveRuntimeOn = false
|
||||
s.liveRuntimeMu.Unlock()
|
||||
return
|
||||
}
|
||||
s.liveRuntimeMu.Unlock()
|
||||
}
|
||||
<-ticker.C
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) publishRuntimeSnapshot(ctx context.Context) bool {
|
||||
if s == nil {
|
||||
return false
|
||||
}
|
||||
payload := map[string]interface{}{
|
||||
"ok": true,
|
||||
"type": "runtime_snapshot",
|
||||
"snapshot": s.buildWebUIRuntimeSnapshot(ctx),
|
||||
}
|
||||
data, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
s.liveRuntimeMu.Lock()
|
||||
defer s.liveRuntimeMu.Unlock()
|
||||
if len(s.liveRuntimeSubs) == 0 {
|
||||
return false
|
||||
}
|
||||
publishLiveSnapshot(s.liveRuntimeSubs, data)
|
||||
return true
|
||||
}
|
||||
|
||||
func buildSubagentLiveKey(taskID, previewTaskID string) string {
|
||||
return strings.TrimSpace(taskID) + "\x00" + strings.TrimSpace(previewTaskID)
|
||||
}
|
||||
|
||||
func (s *Server) subscribeSubagentLive(ctx context.Context, taskID, previewTaskID string) chan []byte {
|
||||
ch := make(chan []byte, 1)
|
||||
key := buildSubagentLiveKey(taskID, previewTaskID)
|
||||
s.liveSubagentMu.Lock()
|
||||
group := s.liveSubagents[key]
|
||||
if group == nil {
|
||||
group = &liveSubagentGroup{
|
||||
taskID: strings.TrimSpace(taskID),
|
||||
previewTaskID: strings.TrimSpace(previewTaskID),
|
||||
subs: map[chan []byte]struct{}{},
|
||||
stopCh: make(chan struct{}),
|
||||
}
|
||||
s.liveSubagents[key] = group
|
||||
go s.subagentLiveLoop(key, group)
|
||||
}
|
||||
group.subs[ch] = struct{}{}
|
||||
s.liveSubagentMu.Unlock()
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
s.unsubscribeSubagentLive(key, ch)
|
||||
}()
|
||||
return ch
|
||||
}
|
||||
|
||||
func (s *Server) unsubscribeSubagentLive(key string, ch chan []byte) {
|
||||
s.liveSubagentMu.Lock()
|
||||
group := s.liveSubagents[key]
|
||||
if group == nil {
|
||||
s.liveSubagentMu.Unlock()
|
||||
return
|
||||
}
|
||||
delete(group.subs, ch)
|
||||
if len(group.subs) == 0 {
|
||||
delete(s.liveSubagents, key)
|
||||
close(group.stopCh)
|
||||
}
|
||||
s.liveSubagentMu.Unlock()
|
||||
}
|
||||
|
||||
func (s *Server) subagentLiveLoop(key string, group *liveSubagentGroup) {
|
||||
ticker := time.NewTicker(2 * time.Second)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
if !s.publishSubagentLiveSnapshot(context.Background(), key, group.taskID, group.previewTaskID) {
|
||||
return
|
||||
}
|
||||
select {
|
||||
case <-group.stopCh:
|
||||
return
|
||||
case <-ticker.C:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) publishSubagentLiveSnapshot(ctx context.Context, key, taskID, previewTaskID string) bool {
|
||||
if s == nil {
|
||||
return false
|
||||
}
|
||||
payload := map[string]interface{}{
|
||||
"ok": true,
|
||||
"type": "subagents_live",
|
||||
"payload": s.buildSubagentsLivePayload(ctx, taskID, previewTaskID),
|
||||
}
|
||||
data, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
s.liveSubagentMu.Lock()
|
||||
defer s.liveSubagentMu.Unlock()
|
||||
group := s.liveSubagents[key]
|
||||
if group == nil || len(group.subs) == 0 {
|
||||
return false
|
||||
}
|
||||
publishLiveSnapshot(group.subs, data)
|
||||
return true
|
||||
}
|
||||
|
||||
func (s *Server) SetConfigPath(path string) { s.configPath = strings.TrimSpace(path) }
|
||||
func (s *Server) SetWorkspacePath(path string) { s.workspacePath = strings.TrimSpace(path) }
|
||||
func (s *Server) SetLogFilePath(path string) { s.logFilePath = strings.TrimSpace(path) }
|
||||
@@ -977,28 +1153,23 @@ func (s *Server) handleWebUIRuntime(w http.ResponseWriter, r *http.Request) {
|
||||
defer conn.Close()
|
||||
|
||||
ctx := r.Context()
|
||||
ticker := time.NewTicker(10 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
sendSnapshot := func() error {
|
||||
payload := map[string]interface{}{
|
||||
"ok": true,
|
||||
"type": "runtime_snapshot",
|
||||
"snapshot": s.buildWebUIRuntimeSnapshot(ctx),
|
||||
}
|
||||
_ = conn.SetWriteDeadline(time.Now().Add(10 * time.Second))
|
||||
return conn.WriteJSON(payload)
|
||||
sub := s.subscribeRuntimeLive(ctx)
|
||||
initial := map[string]interface{}{
|
||||
"ok": true,
|
||||
"type": "runtime_snapshot",
|
||||
"snapshot": s.buildWebUIRuntimeSnapshot(ctx),
|
||||
}
|
||||
|
||||
if err := sendSnapshot(); err != nil {
|
||||
_ = conn.SetWriteDeadline(time.Now().Add(10 * time.Second))
|
||||
if err := conn.WriteJSON(initial); err != nil {
|
||||
return
|
||||
}
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
if err := sendSnapshot(); err != nil {
|
||||
case payload := <-sub:
|
||||
_ = conn.SetWriteDeadline(time.Now().Add(10 * time.Second))
|
||||
if err := conn.WriteMessage(websocket.TextMessage, payload); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -4245,28 +4416,23 @@ func (s *Server) handleWebUISubagentsRuntimeLive(w http.ResponseWriter, r *http.
|
||||
ctx := r.Context()
|
||||
taskID := strings.TrimSpace(r.URL.Query().Get("task_id"))
|
||||
previewTaskID := strings.TrimSpace(r.URL.Query().Get("preview_task_id"))
|
||||
ticker := time.NewTicker(2 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
sendSnapshot := func() error {
|
||||
payload := map[string]interface{}{
|
||||
"ok": true,
|
||||
"type": "subagents_live",
|
||||
"payload": s.buildSubagentsLivePayload(ctx, taskID, previewTaskID),
|
||||
}
|
||||
_ = conn.SetWriteDeadline(time.Now().Add(10 * time.Second))
|
||||
return conn.WriteJSON(payload)
|
||||
sub := s.subscribeSubagentLive(ctx, taskID, previewTaskID)
|
||||
initial := map[string]interface{}{
|
||||
"ok": true,
|
||||
"type": "subagents_live",
|
||||
"payload": s.buildSubagentsLivePayload(ctx, taskID, previewTaskID),
|
||||
}
|
||||
|
||||
if err := sendSnapshot(); err != nil {
|
||||
_ = conn.SetWriteDeadline(time.Now().Add(10 * time.Second))
|
||||
if err := conn.WriteJSON(initial); err != nil {
|
||||
return
|
||||
}
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
if err := sendSnapshot(); err != nil {
|
||||
case payload := <-sub:
|
||||
_ = conn.SetWriteDeadline(time.Now().Add(10 * time.Second))
|
||||
if err := conn.WriteMessage(websocket.TextMessage, payload); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
@@ -53,6 +53,7 @@ type SubagentTask struct {
|
||||
type SubagentManager struct {
|
||||
tasks map[string]*SubagentTask
|
||||
cancelFuncs map[string]context.CancelFunc
|
||||
waiters map[string]map[chan struct{}]struct{}
|
||||
recoverableTaskIDs []string
|
||||
archiveAfterMinute int64
|
||||
mu sync.RWMutex
|
||||
@@ -92,6 +93,7 @@ func NewSubagentManager(provider providers.LLMProvider, workspace string, bus *b
|
||||
mgr := &SubagentManager{
|
||||
tasks: make(map[string]*SubagentTask),
|
||||
cancelFuncs: make(map[string]context.CancelFunc),
|
||||
waiters: make(map[string]map[chan struct{}]struct{}),
|
||||
archiveAfterMinute: 60,
|
||||
provider: provider,
|
||||
bus: bus,
|
||||
@@ -356,6 +358,7 @@ func (sm *SubagentManager) runTask(ctx context.Context, task *SubagentTask) {
|
||||
CreatedAt: task.Updated,
|
||||
})
|
||||
sm.persistTaskLocked(task, "completed", task.Result)
|
||||
sm.notifyTaskWaitersLocked(task.ID)
|
||||
} else {
|
||||
task.Status = "completed"
|
||||
task.Result = applySubagentResultQuota(result, task.MaxResultChars)
|
||||
@@ -373,6 +376,7 @@ func (sm *SubagentManager) runTask(ctx context.Context, task *SubagentTask) {
|
||||
CreatedAt: task.Updated,
|
||||
})
|
||||
sm.persistTaskLocked(task, "completed", task.Result)
|
||||
sm.notifyTaskWaitersLocked(task.ID)
|
||||
}
|
||||
sm.mu.Unlock()
|
||||
|
||||
@@ -790,6 +794,7 @@ func (sm *SubagentManager) KillTask(taskID string) bool {
|
||||
t.WaitingReply = false
|
||||
t.Updated = time.Now().UnixMilli()
|
||||
sm.persistTaskLocked(t, "killed", "")
|
||||
sm.notifyTaskWaitersLocked(taskID)
|
||||
}
|
||||
return true
|
||||
}
|
||||
@@ -1013,6 +1018,96 @@ func (sm *SubagentManager) persistTaskLocked(task *SubagentTask, eventType, mess
|
||||
})
|
||||
}
|
||||
|
||||
func (sm *SubagentManager) WaitTask(ctx context.Context, taskID string) (*SubagentTask, bool, error) {
|
||||
if sm == nil {
|
||||
return nil, false, fmt.Errorf("subagent manager not available")
|
||||
}
|
||||
taskID = strings.TrimSpace(taskID)
|
||||
if taskID == "" {
|
||||
return nil, false, fmt.Errorf("task id is required")
|
||||
}
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
ch := make(chan struct{}, 1)
|
||||
sm.mu.Lock()
|
||||
sm.pruneArchivedLocked()
|
||||
task, ok := sm.tasks[taskID]
|
||||
if !ok && sm.runStore != nil {
|
||||
if persisted, found := sm.runStore.Get(taskID); found && persisted != nil {
|
||||
if strings.TrimSpace(persisted.Status) != "running" {
|
||||
sm.mu.Unlock()
|
||||
return persisted, true, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
if ok && task != nil && strings.TrimSpace(task.Status) != "running" {
|
||||
cp := cloneSubagentTask(task)
|
||||
sm.mu.Unlock()
|
||||
return cp, true, nil
|
||||
}
|
||||
waiters := sm.waiters[taskID]
|
||||
if waiters == nil {
|
||||
waiters = map[chan struct{}]struct{}{}
|
||||
sm.waiters[taskID] = waiters
|
||||
}
|
||||
waiters[ch] = struct{}{}
|
||||
sm.mu.Unlock()
|
||||
|
||||
defer sm.removeTaskWaiter(taskID, ch)
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, false, ctx.Err()
|
||||
case <-ch:
|
||||
sm.mu.Lock()
|
||||
sm.pruneArchivedLocked()
|
||||
task, ok := sm.tasks[taskID]
|
||||
if ok && task != nil && strings.TrimSpace(task.Status) != "running" {
|
||||
cp := cloneSubagentTask(task)
|
||||
sm.mu.Unlock()
|
||||
return cp, true, nil
|
||||
}
|
||||
if !ok && sm.runStore != nil {
|
||||
if persisted, found := sm.runStore.Get(taskID); found && persisted != nil && strings.TrimSpace(persisted.Status) != "running" {
|
||||
sm.mu.Unlock()
|
||||
return persisted, true, nil
|
||||
}
|
||||
}
|
||||
sm.mu.Unlock()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (sm *SubagentManager) removeTaskWaiter(taskID string, ch chan struct{}) {
|
||||
sm.mu.Lock()
|
||||
defer sm.mu.Unlock()
|
||||
waiters := sm.waiters[taskID]
|
||||
if len(waiters) == 0 {
|
||||
delete(sm.waiters, taskID)
|
||||
return
|
||||
}
|
||||
delete(waiters, ch)
|
||||
if len(waiters) == 0 {
|
||||
delete(sm.waiters, taskID)
|
||||
}
|
||||
}
|
||||
|
||||
func (sm *SubagentManager) notifyTaskWaitersLocked(taskID string) {
|
||||
waiters := sm.waiters[taskID]
|
||||
if len(waiters) == 0 {
|
||||
delete(sm.waiters, taskID)
|
||||
return
|
||||
}
|
||||
for ch := range waiters {
|
||||
select {
|
||||
case ch <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
delete(sm.waiters, taskID)
|
||||
}
|
||||
|
||||
func (sm *SubagentManager) recordMailboxMessageLocked(task *SubagentTask, msg AgentMessage) {
|
||||
if sm.mailboxStore == nil || task == nil {
|
||||
return
|
||||
|
||||
@@ -8,20 +8,21 @@ import (
|
||||
)
|
||||
|
||||
type RouterDispatchRequest struct {
|
||||
Task string
|
||||
Label string
|
||||
Role string
|
||||
AgentID string
|
||||
ThreadID string
|
||||
CorrelationID string
|
||||
ParentRunID string
|
||||
OriginChannel string
|
||||
OriginChatID string
|
||||
MaxRetries int
|
||||
RetryBackoff int
|
||||
TimeoutSec int
|
||||
MaxTaskChars int
|
||||
MaxResultChars int
|
||||
Task string
|
||||
Label string
|
||||
Role string
|
||||
AgentID string
|
||||
NotifyMainPolicy string
|
||||
ThreadID string
|
||||
CorrelationID string
|
||||
ParentRunID string
|
||||
OriginChannel string
|
||||
OriginChatID string
|
||||
MaxRetries int
|
||||
RetryBackoff int
|
||||
TimeoutSec int
|
||||
MaxTaskChars int
|
||||
MaxResultChars int
|
||||
}
|
||||
|
||||
type RouterReply struct {
|
||||
@@ -46,20 +47,21 @@ func (r *SubagentRouter) DispatchTask(ctx context.Context, req RouterDispatchReq
|
||||
return nil, fmt.Errorf("subagent router is not configured")
|
||||
}
|
||||
task, err := r.manager.SpawnTask(ctx, SubagentSpawnOptions{
|
||||
Task: req.Task,
|
||||
Label: req.Label,
|
||||
Role: req.Role,
|
||||
AgentID: req.AgentID,
|
||||
ThreadID: req.ThreadID,
|
||||
CorrelationID: req.CorrelationID,
|
||||
ParentRunID: req.ParentRunID,
|
||||
OriginChannel: req.OriginChannel,
|
||||
OriginChatID: req.OriginChatID,
|
||||
MaxRetries: req.MaxRetries,
|
||||
RetryBackoff: req.RetryBackoff,
|
||||
TimeoutSec: req.TimeoutSec,
|
||||
MaxTaskChars: req.MaxTaskChars,
|
||||
MaxResultChars: req.MaxResultChars,
|
||||
Task: req.Task,
|
||||
Label: req.Label,
|
||||
Role: req.Role,
|
||||
AgentID: req.AgentID,
|
||||
NotifyMainPolicy: req.NotifyMainPolicy,
|
||||
ThreadID: req.ThreadID,
|
||||
CorrelationID: req.CorrelationID,
|
||||
ParentRunID: req.ParentRunID,
|
||||
OriginChannel: req.OriginChannel,
|
||||
OriginChatID: req.OriginChatID,
|
||||
MaxRetries: req.MaxRetries,
|
||||
RetryBackoff: req.RetryBackoff,
|
||||
TimeoutSec: req.TimeoutSec,
|
||||
MaxTaskChars: req.MaxTaskChars,
|
||||
MaxResultChars: req.MaxResultChars,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -71,33 +73,26 @@ func (r *SubagentRouter) WaitReply(ctx context.Context, taskID string, interval
|
||||
if r == nil || r.manager == nil {
|
||||
return nil, fmt.Errorf("subagent router is not configured")
|
||||
}
|
||||
if interval <= 0 {
|
||||
interval = 100 * time.Millisecond
|
||||
}
|
||||
_ = interval
|
||||
taskID = strings.TrimSpace(taskID)
|
||||
if taskID == "" {
|
||||
return nil, fmt.Errorf("task id is required")
|
||||
}
|
||||
ticker := time.NewTicker(interval)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
task, ok := r.manager.GetTask(taskID)
|
||||
if ok && task != nil && task.Status != "running" {
|
||||
return &RouterReply{
|
||||
TaskID: task.ID,
|
||||
ThreadID: task.ThreadID,
|
||||
CorrelationID: task.CorrelationID,
|
||||
AgentID: task.AgentID,
|
||||
Status: task.Status,
|
||||
Result: strings.TrimSpace(task.Result),
|
||||
}, nil
|
||||
}
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case <-ticker.C:
|
||||
}
|
||||
task, ok, err := r.manager.WaitTask(ctx, taskID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !ok || task == nil {
|
||||
return nil, fmt.Errorf("subagent not found")
|
||||
}
|
||||
return &RouterReply{
|
||||
TaskID: task.ID,
|
||||
ThreadID: task.ThreadID,
|
||||
CorrelationID: task.CorrelationID,
|
||||
AgentID: task.AgentID,
|
||||
Status: task.Status,
|
||||
Result: strings.TrimSpace(task.Result),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (r *SubagentRouter) MergeResults(replies []*RouterReply) string {
|
||||
|
||||
@@ -47,3 +47,29 @@ func TestSubagentRouterMergeResults(t *testing.T) {
|
||||
t.Fatalf("unexpected merged output: %s", out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSubagentRouterWaitReplyContextCancel(t *testing.T) {
|
||||
workspace := t.TempDir()
|
||||
manager := NewSubagentManager(nil, workspace, nil)
|
||||
manager.SetRunFunc(func(ctx context.Context, task *SubagentTask) (string, error) {
|
||||
<-ctx.Done()
|
||||
return "", ctx.Err()
|
||||
})
|
||||
router := NewSubagentRouter(manager)
|
||||
|
||||
task, err := router.DispatchTask(context.Background(), RouterDispatchRequest{
|
||||
Task: "long task",
|
||||
AgentID: "coder",
|
||||
OriginChannel: "cli",
|
||||
OriginChatID: "direct",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("dispatch failed: %v", err)
|
||||
}
|
||||
|
||||
waitCtx, cancel := context.WithTimeout(context.Background(), 20*time.Millisecond)
|
||||
defer cancel()
|
||||
if _, err := router.WaitReply(waitCtx, task.ID, 20*time.Millisecond); err == nil {
|
||||
t.Fatalf("expected context cancellation error")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -25,6 +25,8 @@ const (
|
||||
maxWorldCycle = 60 * time.Second
|
||||
)
|
||||
|
||||
const GlobalWatchdogTick = watchdogTick
|
||||
|
||||
var ErrCommandNoProgress = errors.New("command no progress across tick rounds")
|
||||
var ErrTaskWatchdogTimeout = errors.New("task watchdog timeout exceeded")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user