mirror of
https://github.com/YspCoder/clawgo.git
synced 2026-04-29 22:27:35 +08:00
Optimize agent planning and subagent runtime
This commit is contained in:
@@ -53,6 +53,7 @@ type SubagentTask struct {
|
||||
type SubagentManager struct {
|
||||
tasks map[string]*SubagentTask
|
||||
cancelFuncs map[string]context.CancelFunc
|
||||
waiters map[string]map[chan struct{}]struct{}
|
||||
recoverableTaskIDs []string
|
||||
archiveAfterMinute int64
|
||||
mu sync.RWMutex
|
||||
@@ -92,6 +93,7 @@ func NewSubagentManager(provider providers.LLMProvider, workspace string, bus *b
|
||||
mgr := &SubagentManager{
|
||||
tasks: make(map[string]*SubagentTask),
|
||||
cancelFuncs: make(map[string]context.CancelFunc),
|
||||
waiters: make(map[string]map[chan struct{}]struct{}),
|
||||
archiveAfterMinute: 60,
|
||||
provider: provider,
|
||||
bus: bus,
|
||||
@@ -356,6 +358,7 @@ func (sm *SubagentManager) runTask(ctx context.Context, task *SubagentTask) {
|
||||
CreatedAt: task.Updated,
|
||||
})
|
||||
sm.persistTaskLocked(task, "completed", task.Result)
|
||||
sm.notifyTaskWaitersLocked(task.ID)
|
||||
} else {
|
||||
task.Status = "completed"
|
||||
task.Result = applySubagentResultQuota(result, task.MaxResultChars)
|
||||
@@ -373,6 +376,7 @@ func (sm *SubagentManager) runTask(ctx context.Context, task *SubagentTask) {
|
||||
CreatedAt: task.Updated,
|
||||
})
|
||||
sm.persistTaskLocked(task, "completed", task.Result)
|
||||
sm.notifyTaskWaitersLocked(task.ID)
|
||||
}
|
||||
sm.mu.Unlock()
|
||||
|
||||
@@ -790,6 +794,7 @@ func (sm *SubagentManager) KillTask(taskID string) bool {
|
||||
t.WaitingReply = false
|
||||
t.Updated = time.Now().UnixMilli()
|
||||
sm.persistTaskLocked(t, "killed", "")
|
||||
sm.notifyTaskWaitersLocked(taskID)
|
||||
}
|
||||
return true
|
||||
}
|
||||
@@ -1013,6 +1018,96 @@ func (sm *SubagentManager) persistTaskLocked(task *SubagentTask, eventType, mess
|
||||
})
|
||||
}
|
||||
|
||||
func (sm *SubagentManager) WaitTask(ctx context.Context, taskID string) (*SubagentTask, bool, error) {
|
||||
if sm == nil {
|
||||
return nil, false, fmt.Errorf("subagent manager not available")
|
||||
}
|
||||
taskID = strings.TrimSpace(taskID)
|
||||
if taskID == "" {
|
||||
return nil, false, fmt.Errorf("task id is required")
|
||||
}
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
ch := make(chan struct{}, 1)
|
||||
sm.mu.Lock()
|
||||
sm.pruneArchivedLocked()
|
||||
task, ok := sm.tasks[taskID]
|
||||
if !ok && sm.runStore != nil {
|
||||
if persisted, found := sm.runStore.Get(taskID); found && persisted != nil {
|
||||
if strings.TrimSpace(persisted.Status) != "running" {
|
||||
sm.mu.Unlock()
|
||||
return persisted, true, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
if ok && task != nil && strings.TrimSpace(task.Status) != "running" {
|
||||
cp := cloneSubagentTask(task)
|
||||
sm.mu.Unlock()
|
||||
return cp, true, nil
|
||||
}
|
||||
waiters := sm.waiters[taskID]
|
||||
if waiters == nil {
|
||||
waiters = map[chan struct{}]struct{}{}
|
||||
sm.waiters[taskID] = waiters
|
||||
}
|
||||
waiters[ch] = struct{}{}
|
||||
sm.mu.Unlock()
|
||||
|
||||
defer sm.removeTaskWaiter(taskID, ch)
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, false, ctx.Err()
|
||||
case <-ch:
|
||||
sm.mu.Lock()
|
||||
sm.pruneArchivedLocked()
|
||||
task, ok := sm.tasks[taskID]
|
||||
if ok && task != nil && strings.TrimSpace(task.Status) != "running" {
|
||||
cp := cloneSubagentTask(task)
|
||||
sm.mu.Unlock()
|
||||
return cp, true, nil
|
||||
}
|
||||
if !ok && sm.runStore != nil {
|
||||
if persisted, found := sm.runStore.Get(taskID); found && persisted != nil && strings.TrimSpace(persisted.Status) != "running" {
|
||||
sm.mu.Unlock()
|
||||
return persisted, true, nil
|
||||
}
|
||||
}
|
||||
sm.mu.Unlock()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (sm *SubagentManager) removeTaskWaiter(taskID string, ch chan struct{}) {
|
||||
sm.mu.Lock()
|
||||
defer sm.mu.Unlock()
|
||||
waiters := sm.waiters[taskID]
|
||||
if len(waiters) == 0 {
|
||||
delete(sm.waiters, taskID)
|
||||
return
|
||||
}
|
||||
delete(waiters, ch)
|
||||
if len(waiters) == 0 {
|
||||
delete(sm.waiters, taskID)
|
||||
}
|
||||
}
|
||||
|
||||
func (sm *SubagentManager) notifyTaskWaitersLocked(taskID string) {
|
||||
waiters := sm.waiters[taskID]
|
||||
if len(waiters) == 0 {
|
||||
delete(sm.waiters, taskID)
|
||||
return
|
||||
}
|
||||
for ch := range waiters {
|
||||
select {
|
||||
case ch <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
delete(sm.waiters, taskID)
|
||||
}
|
||||
|
||||
func (sm *SubagentManager) recordMailboxMessageLocked(task *SubagentTask, msg AgentMessage) {
|
||||
if sm.mailboxStore == nil || task == nil {
|
||||
return
|
||||
|
||||
@@ -8,20 +8,21 @@ import (
|
||||
)
|
||||
|
||||
type RouterDispatchRequest struct {
|
||||
Task string
|
||||
Label string
|
||||
Role string
|
||||
AgentID string
|
||||
ThreadID string
|
||||
CorrelationID string
|
||||
ParentRunID string
|
||||
OriginChannel string
|
||||
OriginChatID string
|
||||
MaxRetries int
|
||||
RetryBackoff int
|
||||
TimeoutSec int
|
||||
MaxTaskChars int
|
||||
MaxResultChars int
|
||||
Task string
|
||||
Label string
|
||||
Role string
|
||||
AgentID string
|
||||
NotifyMainPolicy string
|
||||
ThreadID string
|
||||
CorrelationID string
|
||||
ParentRunID string
|
||||
OriginChannel string
|
||||
OriginChatID string
|
||||
MaxRetries int
|
||||
RetryBackoff int
|
||||
TimeoutSec int
|
||||
MaxTaskChars int
|
||||
MaxResultChars int
|
||||
}
|
||||
|
||||
type RouterReply struct {
|
||||
@@ -46,20 +47,21 @@ func (r *SubagentRouter) DispatchTask(ctx context.Context, req RouterDispatchReq
|
||||
return nil, fmt.Errorf("subagent router is not configured")
|
||||
}
|
||||
task, err := r.manager.SpawnTask(ctx, SubagentSpawnOptions{
|
||||
Task: req.Task,
|
||||
Label: req.Label,
|
||||
Role: req.Role,
|
||||
AgentID: req.AgentID,
|
||||
ThreadID: req.ThreadID,
|
||||
CorrelationID: req.CorrelationID,
|
||||
ParentRunID: req.ParentRunID,
|
||||
OriginChannel: req.OriginChannel,
|
||||
OriginChatID: req.OriginChatID,
|
||||
MaxRetries: req.MaxRetries,
|
||||
RetryBackoff: req.RetryBackoff,
|
||||
TimeoutSec: req.TimeoutSec,
|
||||
MaxTaskChars: req.MaxTaskChars,
|
||||
MaxResultChars: req.MaxResultChars,
|
||||
Task: req.Task,
|
||||
Label: req.Label,
|
||||
Role: req.Role,
|
||||
AgentID: req.AgentID,
|
||||
NotifyMainPolicy: req.NotifyMainPolicy,
|
||||
ThreadID: req.ThreadID,
|
||||
CorrelationID: req.CorrelationID,
|
||||
ParentRunID: req.ParentRunID,
|
||||
OriginChannel: req.OriginChannel,
|
||||
OriginChatID: req.OriginChatID,
|
||||
MaxRetries: req.MaxRetries,
|
||||
RetryBackoff: req.RetryBackoff,
|
||||
TimeoutSec: req.TimeoutSec,
|
||||
MaxTaskChars: req.MaxTaskChars,
|
||||
MaxResultChars: req.MaxResultChars,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -71,33 +73,26 @@ func (r *SubagentRouter) WaitReply(ctx context.Context, taskID string, interval
|
||||
if r == nil || r.manager == nil {
|
||||
return nil, fmt.Errorf("subagent router is not configured")
|
||||
}
|
||||
if interval <= 0 {
|
||||
interval = 100 * time.Millisecond
|
||||
}
|
||||
_ = interval
|
||||
taskID = strings.TrimSpace(taskID)
|
||||
if taskID == "" {
|
||||
return nil, fmt.Errorf("task id is required")
|
||||
}
|
||||
ticker := time.NewTicker(interval)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
task, ok := r.manager.GetTask(taskID)
|
||||
if ok && task != nil && task.Status != "running" {
|
||||
return &RouterReply{
|
||||
TaskID: task.ID,
|
||||
ThreadID: task.ThreadID,
|
||||
CorrelationID: task.CorrelationID,
|
||||
AgentID: task.AgentID,
|
||||
Status: task.Status,
|
||||
Result: strings.TrimSpace(task.Result),
|
||||
}, nil
|
||||
}
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case <-ticker.C:
|
||||
}
|
||||
task, ok, err := r.manager.WaitTask(ctx, taskID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !ok || task == nil {
|
||||
return nil, fmt.Errorf("subagent not found")
|
||||
}
|
||||
return &RouterReply{
|
||||
TaskID: task.ID,
|
||||
ThreadID: task.ThreadID,
|
||||
CorrelationID: task.CorrelationID,
|
||||
AgentID: task.AgentID,
|
||||
Status: task.Status,
|
||||
Result: strings.TrimSpace(task.Result),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (r *SubagentRouter) MergeResults(replies []*RouterReply) string {
|
||||
|
||||
@@ -47,3 +47,29 @@ func TestSubagentRouterMergeResults(t *testing.T) {
|
||||
t.Fatalf("unexpected merged output: %s", out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSubagentRouterWaitReplyContextCancel(t *testing.T) {
|
||||
workspace := t.TempDir()
|
||||
manager := NewSubagentManager(nil, workspace, nil)
|
||||
manager.SetRunFunc(func(ctx context.Context, task *SubagentTask) (string, error) {
|
||||
<-ctx.Done()
|
||||
return "", ctx.Err()
|
||||
})
|
||||
router := NewSubagentRouter(manager)
|
||||
|
||||
task, err := router.DispatchTask(context.Background(), RouterDispatchRequest{
|
||||
Task: "long task",
|
||||
AgentID: "coder",
|
||||
OriginChannel: "cli",
|
||||
OriginChatID: "direct",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("dispatch failed: %v", err)
|
||||
}
|
||||
|
||||
waitCtx, cancel := context.WithTimeout(context.Background(), 20*time.Millisecond)
|
||||
defer cancel()
|
||||
if _, err := router.WaitReply(waitCtx, task.ID, 20*time.Millisecond); err == nil {
|
||||
t.Fatalf("expected context cancellation error")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -25,6 +25,8 @@ const (
|
||||
maxWorldCycle = 60 * time.Second
|
||||
)
|
||||
|
||||
const GlobalWatchdogTick = watchdogTick
|
||||
|
||||
var ErrCommandNoProgress = errors.New("command no progress across tick rounds")
|
||||
var ErrTaskWatchdogTimeout = errors.New("task watchdog timeout exceeded")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user