Files
clawgo/pkg/tools/subagent.go

472 lines
12 KiB
Go

package tools
import (
"context"
"fmt"
"os"
"path/filepath"
"strings"
"sync"
"time"
"clawgo/pkg/bus"
"clawgo/pkg/providers"
)
// SubagentTask is the manager's record of a single spawned subagent run:
// what it was asked to do, which agent profile it runs as, where the request
// came from, and its current status and result.
type SubagentTask struct {
	// ID is the manager-assigned identifier, "subagent-<n>".
	ID string
	// Task is the instruction text, Label an optional display name, and Role
	// the requested (or profile-provided) role.
	Task, Label, Role string
	// AgentID is the resolved, normalized agent identifier; SessionKey is the
	// derived "subagent:<agent>:<task id>" session key.
	AgentID, SessionKey string
	// MemoryNS is the profile's memory namespace, or AgentID when the profile
	// does not provide one.
	MemoryNS string
	// SystemPrompt and ToolAllowlist are copied from the resolved agent
	// profile; both stay empty when no profile applied.
	SystemPrompt  string
	ToolAllowlist []string
	// PipelineID and PipelineTask link the task to an orchestrator pipeline;
	// pipeline bookkeeping only happens when both are non-empty.
	PipelineID, PipelineTask string
	// SharedState is not populated by SubagentManager itself.
	// NOTE(review): presumably filled in by pipeline code elsewhere — confirm.
	SharedState map[string]interface{}
	// OriginChannel and OriginChatID identify where the request came from; the
	// completion announcement is addressed to "<channel>:<chatID>".
	OriginChannel, OriginChatID string
	// Status is set by SubagentManager to "running", "completed", "failed",
	// or "killed".
	Status string
	// Result holds the subagent output, or "Error: ..." text on failure.
	Result string
	// Steering collects follow-up messages appended by SteerTask.
	Steering []string
	// Created and Updated are Unix timestamps in milliseconds.
	Created, Updated int64
}
// SubagentManager spawns subagent tasks, keeps them in an in-memory registry,
// and lets callers inspect, steer, kill, and resume them.
type SubagentManager struct {
	// tasks holds every known task by ID; finished tasks are pruned (on
	// GetTask/ListTasks) once not updated for archiveAfterMinute minutes.
	tasks map[string]*SubagentTask
	// cancelFuncs holds the context cancel func of tasks that may still run.
	cancelFuncs map[string]context.CancelFunc
	// archiveAfterMinute is the pruning age in minutes; <= 0 disables pruning.
	archiveAfterMinute int64
	// mu guards the task registry (tasks, cancelFuncs, nextID), runFunc, and
	// the mutable fields of registered tasks.
	mu sync.RWMutex
	// provider is the LLM used by the one-shot fallback when runFunc is nil.
	provider providers.LLMProvider
	// bus, when non-nil, receives task completion announcements.
	bus *bus.MessageBus
	// orc, when non-nil, is told about pipeline task progress.
	orc *Orchestrator
	// workspace is the workspace root; its AGENTS.md feeds the one-shot prompt.
	workspace string
	// nextID is the numeric suffix of the next "subagent-<n>" task ID.
	nextID int
	// runFunc, when set, executes tasks with the full agent loop.
	runFunc SubagentRunFunc
	// profileStore resolves agent profiles by ID or role; may be nil.
	profileStore *SubagentProfileStore
}
// SubagentSpawnOptions describes a subagent task to launch with Spawn. Spawn
// trims surrounding whitespace from every field (AgentID is normalized);
// only Task is required.
type SubagentSpawnOptions struct {
	// Task is the instruction text for the subagent.
	Task string
	// Label is an optional display name; Role selects a profile by role when
	// AgentID is empty.
	Label, Role string
	// AgentID selects an agent profile directly.
	AgentID string
	// OriginChannel and OriginChatID say where the result is announced.
	OriginChannel, OriginChatID string
	// PipelineID and PipelineTask link the task to an orchestrator pipeline.
	PipelineID, PipelineTask string
}
// NewSubagentManager returns a manager wired to the given LLM provider,
// workspace, message bus and orchestrator. It opens a profile store for
// workspace, numbers tasks from 1, and archives finished tasks after 60
// minutes.
func NewSubagentManager(provider providers.LLMProvider, workspace string, bus *bus.MessageBus, orc *Orchestrator) *SubagentManager {
	sm := &SubagentManager{
		provider:     provider,
		bus:          bus,
		orc:          orc,
		workspace:    workspace,
		profileStore: NewSubagentProfileStore(workspace),
	}
	sm.tasks = make(map[string]*SubagentTask)
	sm.cancelFuncs = make(map[string]context.CancelFunc)
	sm.archiveAfterMinute = 60
	sm.nextID = 1
	return sm
}
// Spawn registers a new subagent task described by opts, starts it on its
// own goroutine, and returns a human-readable description of the launch.
//
// When a profile store is configured, the agent profile is resolved as follows:
//   - a non-empty (normalized) AgentID is looked up directly;
//   - otherwise a non-empty Role is matched with FindByRole, and the matching
//     profile's agent ID is adopted;
//   - if no profile was found, the agent ID falls back to the normalized Role,
//     then to "default", and that ID is looked up as well.
//
// A resolved profile fills in an empty label (from its Name) and an empty
// role, and supplies the memory namespace, system prompt and tool allowlist.
// Spawning against a disabled profile fails; profile store errors are
// returned unchanged. The task runs under a context derived from ctx, so
// cancelling ctx also cancels the task.
func (sm *SubagentManager) Spawn(ctx context.Context, opts SubagentSpawnOptions) (string, error) {
	taskText := strings.TrimSpace(opts.Task)
	if taskText == "" {
		return "", fmt.Errorf("task is required")
	}
	label := strings.TrimSpace(opts.Label)
	role := strings.TrimSpace(opts.Role)
	agentID := normalizeSubagentIdentifier(opts.AgentID)

	var profile *SubagentProfile
	if sm.profileStore != nil {
		switch {
		case agentID != "":
			p, ok, err := sm.profileStore.Get(agentID)
			if err != nil {
				return "", err
			}
			if ok {
				profile = p
			}
		case role != "":
			p, ok, err := sm.profileStore.FindByRole(role)
			if err != nil {
				return "", err
			}
			if ok {
				profile = p
				agentID = normalizeSubagentIdentifier(p.AgentID)
			}
		}
	}
	if agentID == "" {
		agentID = normalizeSubagentIdentifier(role)
	}
	if agentID == "" {
		agentID = "default"
	}
	if profile == nil && sm.profileStore != nil {
		p, ok, err := sm.profileStore.Get(agentID)
		if err != nil {
			return "", err
		}
		if ok {
			profile = p
		}
	}

	memoryNS := agentID
	var (
		systemPrompt  string
		toolAllowlist []string
	)
	if profile != nil {
		if strings.EqualFold(strings.TrimSpace(profile.Status), "disabled") {
			return "", fmt.Errorf("subagent profile '%s' is disabled", profile.AgentID)
		}
		if label == "" {
			label = strings.TrimSpace(profile.Name)
		}
		if role == "" {
			role = strings.TrimSpace(profile.Role)
		}
		if ns := normalizeSubagentIdentifier(profile.MemoryNamespace); ns != "" {
			memoryNS = ns
		}
		systemPrompt = strings.TrimSpace(profile.SystemPrompt)
		// Copy so later edits to the profile don't alias the task's allowlist.
		toolAllowlist = append([]string(nil), profile.ToolAllowlist...)
	}

	originChannel := strings.TrimSpace(opts.OriginChannel)
	originChatID := strings.TrimSpace(opts.OriginChatID)
	pipelineID := strings.TrimSpace(opts.PipelineID)
	pipelineTask := strings.TrimSpace(opts.PipelineTask)

	sm.mu.Lock()
	defer sm.mu.Unlock()

	id := fmt.Sprintf("subagent-%d", sm.nextID)
	sm.nextID++
	now := time.Now().UnixMilli()
	record := &SubagentTask{
		ID:            id,
		Task:          taskText,
		Label:         label,
		Role:          role,
		AgentID:       agentID,
		SessionKey:    buildSubagentSessionKey(agentID, id),
		MemoryNS:      memoryNS,
		SystemPrompt:  systemPrompt,
		ToolAllowlist: toolAllowlist,
		PipelineID:    pipelineID,
		PipelineTask:  pipelineTask,
		OriginChannel: originChannel,
		OriginChatID:  originChatID,
		Status:        "running",
		Created:       now,
		Updated:       now,
	}
	taskCtx, cancel := context.WithCancel(ctx)
	sm.tasks[id] = record
	sm.cancelFuncs[id] = cancel
	// runTask takes sm.mu first, so it only proceeds after Spawn returns.
	go sm.runTask(taskCtx, record)

	desc := fmt.Sprintf("Spawned subagent for task: %s (agent=%s)", taskText, agentID)
	if label != "" {
		desc = fmt.Sprintf("Spawned subagent '%s' for task: %s (agent=%s)", label, taskText, agentID)
	}
	if role != "" {
		desc += fmt.Sprintf(" role=%s", role)
	}
	if pipelineID != "" && pipelineTask != "" {
		desc += fmt.Sprintf(" (pipeline=%s task=%s)", pipelineID, pipelineTask)
	}
	return desc, nil
}
// runTask executes task to completion on the goroutine started by Spawn.
//
// Execution prefers the injected runFunc (a full agent loop with recursive
// tool calling; it is injected because this package cannot import the agent
// package) and falls back to a single provider.Chat call when no runFunc is
// set. Pipeline tasks (PipelineID and PipelineTask both non-empty) are
// reported to the orchestrator when they start and when they finish. The
// outcome is then announced on the message bus. The exception is a missing
// LLM provider: it is recorded as a failure but not announced.
//
// On return the task's context is cancelled, which releases it from its
// parent. Its cancel func is also removed from cancelFuncs.
func (sm *SubagentManager) runTask(ctx context.Context, task *SubagentTask) {
	defer func() {
		sm.mu.Lock()
		cancel, ok := sm.cancelFuncs[task.ID]
		delete(sm.cancelFuncs, task.ID)
		sm.mu.Unlock()
		// Spawn created ctx with context.WithCancel. Its cancel func must be
		// called once the task is over; otherwise the child context stays
		// registered with its parent until the parent itself is cancelled.
		if ok {
			cancel()
		}
	}()

	sm.mu.Lock()
	// KillTask may have run before this goroutine was scheduled; a killed
	// task must not be flipped back to "running".
	if task.Status != "killed" {
		task.Status = "running"
	}
	task.Created = time.Now().UnixMilli()
	task.Updated = task.Created
	// SetRunFunc writes runFunc under sm.mu, so it must be read under sm.mu too.
	runFunc := sm.runFunc
	sm.mu.Unlock()

	if sm.orc != nil && task.PipelineID != "" && task.PipelineTask != "" {
		// Best effort: pipeline bookkeeping errors do not affect the task itself.
		_ = sm.orc.MarkTaskRunning(task.PipelineID, task.PipelineTask)
	}

	var (
		result string
		err    error
	)
	switch {
	case runFunc != nil:
		result, err = runFunc(ctx, task)
	case sm.provider == nil:
		sm.recordTaskOutcome(task, "", fmt.Errorf("no llm provider configured for subagent execution"))
		// As before, this configuration error is recorded but not announced.
		return
	default:
		result, err = sm.runOneShotChat(ctx, task)
	}
	sm.recordTaskOutcome(task, result, err)
	sm.announceTaskResult(task)
}

// recordTaskOutcome stores the final status and result of task and, for
// pipeline tasks, reports completion to the orchestrator.
//
// A non-nil runErr marks the task "failed" with result "Error: <runErr>";
// otherwise the task is "completed" with result. A task already marked
// "killed" keeps that status, but its result is still recorded.
func (sm *SubagentManager) recordTaskOutcome(task *SubagentTask, result string, runErr error) {
	sm.mu.Lock()
	if runErr != nil {
		task.Result = fmt.Sprintf("Error: %v", runErr)
	} else {
		task.Result = result
	}
	if task.Status != "killed" {
		if runErr != nil {
			task.Status = "failed"
		} else {
			task.Status = "completed"
		}
	}
	task.Updated = time.Now().UnixMilli()
	finalResult := task.Result
	sm.mu.Unlock()

	// The orchestrator is called outside sm.mu (as MarkTaskRunning already is)
	// so the manager lock is never held across a call into another component.
	// Best effort: its error is ignored.
	if sm.orc != nil && task.PipelineID != "" && task.PipelineTask != "" {
		_ = sm.orc.MarkTaskDone(task.PipelineID, task.PipelineTask, finalResult, runErr)
	}
}

// runOneShotChat answers task with a single provider.Chat call. The system
// message is built from the workspace AGENTS.md (when present and non-empty)
// plus the task's profile system prompt. It returns the response content.
func (sm *SubagentManager) runOneShotChat(ctx context.Context, task *SubagentTask) (string, error) {
	systemPrompt := "You are a subagent. Follow workspace AGENTS.md and complete the task independently."
	if ws := strings.TrimSpace(sm.workspace); ws != "" {
		// A missing or unreadable AGENTS.md is not an error; the generic prompt is used.
		if data, err := os.ReadFile(filepath.Join(ws, "AGENTS.md")); err == nil {
			if txt := strings.TrimSpace(string(data)); txt != "" {
				systemPrompt = "Workspace policy (AGENTS.md):\n" + txt + "\n\nComplete the given task independently and report the result."
			}
		}
	}
	if rolePrompt := strings.TrimSpace(task.SystemPrompt); rolePrompt != "" {
		systemPrompt += "\n\nRole-specific profile prompt:\n" + rolePrompt
	}
	messages := []providers.Message{
		{Role: "system", Content: systemPrompt},
		{Role: "user", Content: task.Task},
	}
	response, err := sm.provider.Chat(ctx, messages, nil, sm.provider.GetDefaultModel(), map[string]interface{}{
		"max_tokens": 4096,
	})
	if err != nil {
		return "", err
	}
	return response.Content, nil
}

// announceTaskResult publishes task's recorded result to the message bus as
// an inbound "system" message addressed to the task's origin chat. It does
// nothing when no bus is configured.
func (sm *SubagentManager) announceTaskResult(task *SubagentTask) {
	if sm.bus == nil {
		return
	}
	// Snapshot under the lock so the message reflects the recorded outcome.
	sm.mu.RLock()
	t := *task
	sm.mu.RUnlock()

	prefix := "Task completed"
	if t.Label != "" {
		prefix = fmt.Sprintf("Task '%s' completed", t.Label)
	}
	content := fmt.Sprintf("%s.\n\nResult:\n%s", prefix, t.Result)
	if t.PipelineID != "" && t.PipelineTask != "" {
		content += fmt.Sprintf("\n\nPipeline: %s\nPipeline Task: %s", t.PipelineID, t.PipelineTask)
	}
	sm.bus.PublishInbound(bus.InboundMessage{
		Channel:    "system",
		SenderID:   fmt.Sprintf("subagent:%s", t.ID),
		ChatID:     fmt.Sprintf("%s:%s", t.OriginChannel, t.OriginChatID),
		SessionKey: t.SessionKey,
		Content:    content,
		Metadata: map[string]string{
			"trigger":       "subagent",
			"subagent_id":   t.ID,
			"agent_id":      t.AgentID,
			"role":          t.Role,
			"session_key":   t.SessionKey,
			"memory_ns":     t.MemoryNS,
			"pipeline_id":   t.PipelineID,
			"pipeline_task": t.PipelineTask,
		},
	})
}
// SubagentRunFunc runs one subagent task to completion and returns its result
// text. It is injected through SetRunFunc, because this package cannot import
// the agent package. When set, it replaces the one-shot provider chat used by
// the manager.
type SubagentRunFunc func(ctx context.Context, task *SubagentTask) (string, error)

// SetRunFunc installs f as the task executor for tasks that begin running
// after the call. Passing nil restores the one-shot provider fallback.
func (sm *SubagentManager) SetRunFunc(f SubagentRunFunc) {
	sm.mu.Lock()
	sm.runFunc = f
	sm.mu.Unlock()
}
// ProfileStore returns the agent profile store used for spawning (may be nil).
func (sm *SubagentManager) ProfileStore() *SubagentProfileStore {
	sm.mu.RLock()
	store := sm.profileStore
	sm.mu.RUnlock()
	return store
}
// GetTask looks up taskID after pruning archived tasks. The returned pointer
// is the manager's own record, not a copy; ok is false for unknown IDs.
func (sm *SubagentManager) GetTask(taskID string) (*SubagentTask, bool) {
	sm.mu.Lock()
	defer sm.mu.Unlock()

	sm.pruneArchivedLocked()
	if task, ok := sm.tasks[taskID]; ok {
		return task, true
	}
	return nil, false
}
// ListTasks prunes archived tasks and returns the remaining ones. The
// pointers are the manager's own records. The order is unspecified, since it
// follows map iteration.
func (sm *SubagentManager) ListTasks() []*SubagentTask {
	sm.mu.Lock()
	defer sm.mu.Unlock()

	sm.pruneArchivedLocked()
	out := make([]*SubagentTask, len(sm.tasks))
	i := 0
	for _, task := range sm.tasks {
		out[i] = task
		i++
	}
	return out
}
// KillTask cancels the context of taskID and, if the task is still running,
// marks it "killed". A task that has already finished keeps its status. It
// reports whether taskID is a known task.
func (sm *SubagentManager) KillTask(taskID string) bool {
	sm.mu.Lock()
	defer sm.mu.Unlock()

	task, exists := sm.tasks[taskID]
	if !exists {
		return false
	}
	if cancel, hasCancel := sm.cancelFuncs[taskID]; hasCancel {
		cancel()
		delete(sm.cancelFuncs, taskID)
	}
	if task.Status != "running" {
		return true
	}
	task.Status = "killed"
	task.Updated = time.Now().UnixMilli()
	return true
}
// SteerTask appends a trimmed follow-up message to the task's Steering list
// and bumps Updated. It returns false when the task is unknown or the message
// is blank.
func (sm *SubagentManager) SteerTask(taskID, message string) bool {
	msg := strings.TrimSpace(message)

	sm.mu.Lock()
	defer sm.mu.Unlock()

	task, ok := sm.tasks[taskID]
	if !ok || msg == "" {
		return false
	}
	task.Steering = append(task.Steering, msg)
	task.Updated = time.Now().UnixMilli()
	return true
}
// ResumeTask spawns a fresh task that re-runs taskID's instruction with the
// same role, agent, origin and pipeline settings. The new task's label is
// "<label>-resumed", or "resumed" when the original had none.
//
// It returns that label (not the new task ID) and true on success. It returns
// "", false when the task is unknown, has an empty instruction, or Spawn
// fails; the Spawn error itself is not surfaced.
func (sm *SubagentManager) ResumeTask(ctx context.Context, taskID string) (string, bool) {
	sm.mu.RLock()
	prev, found := sm.tasks[taskID]
	sm.mu.RUnlock()

	if !found || strings.TrimSpace(prev.Task) == "" {
		return "", false
	}

	label := "resumed"
	if base := strings.TrimSpace(prev.Label); base != "" {
		label = base + "-resumed"
	}
	opts := SubagentSpawnOptions{
		Task:          prev.Task,
		Label:         label,
		Role:          prev.Role,
		AgentID:       prev.AgentID,
		OriginChannel: prev.OriginChannel,
		OriginChatID:  prev.OriginChatID,
		PipelineID:    prev.PipelineID,
		PipelineTask:  prev.PipelineTask,
	}
	if _, err := sm.Spawn(ctx, opts); err != nil {
		return "", false
	}
	return label, true
}
// pruneArchivedLocked removes tasks that are not "running" and whose Updated
// timestamp is older than archiveAfterMinute minutes, together with any
// leftover cancel func. A non-positive archiveAfterMinute disables pruning.
// Callers must hold sm.mu for writing.
func (sm *SubagentManager) pruneArchivedLocked() {
	if sm.archiveAfterMinute <= 0 {
		return
	}
	maxAge := time.Duration(sm.archiveAfterMinute) * time.Minute
	cutoff := time.Now().Add(-maxAge).UnixMilli()
	for id, task := range sm.tasks {
		// Deleting map entries during range is safe in Go.
		expired := task.Status != "running" && task.Updated > 0 && task.Updated < cutoff
		if !expired {
			continue
		}
		delete(sm.tasks, id)
		delete(sm.cancelFuncs, id)
	}
}
// normalizeSubagentIdentifier canonicalizes a free-form name into an
// identifier. It lower-cases and trims the input, then turns each interior
// space into '-'. It drops every rune other than ASCII letters, digits, '-',
// '_' and '.', and finally strips leading and trailing '-', '_' and '.'.
// It returns "" when nothing survives.
func normalizeSubagentIdentifier(in string) string {
	lowered := strings.TrimSpace(strings.ToLower(in))
	kept := strings.Map(func(r rune) rune {
		switch {
		case 'a' <= r && r <= 'z', '0' <= r && r <= '9', r == '-', r == '_', r == '.':
			return r
		case r == ' ':
			return '-'
		default:
			return -1 // a negative result tells strings.Map to drop the rune
		}
	}, lowered)
	return strings.Trim(kept, "-_.")
}

// buildSubagentSessionKey returns "subagent:<agent>:<task>". Both parts are
// normalized; an empty agent becomes "default" and an empty task becomes
// "task".
func buildSubagentSessionKey(agentID, taskID string) string {
	agent := normalizeSubagentIdentifier(agentID)
	if agent == "" {
		agent = "default"
	}
	task := normalizeSubagentIdentifier(taskID)
	if task == "" {
		task = "task"
	}
	return fmt.Sprintf("subagent:%s:%s", agent, task)
}