mirror of
https://github.com/YspCoder/clawgo.git
synced 2026-04-15 00:27:29 +08:00
feat: ship subagent runtime and remove autonomy/task legacy
This commit is contained in:
@@ -26,6 +26,7 @@ const (
|
||||
)
|
||||
|
||||
var ErrCommandNoProgress = errors.New("command no progress across tick rounds")
|
||||
var ErrCommandTickTimeout = errors.New("command tick timeout exceeded")
|
||||
|
||||
type commandRuntimePolicy struct {
|
||||
BaseTick time.Duration
|
||||
@@ -594,6 +595,75 @@ func runCommandWithDynamicTick(ctx context.Context, cmd *exec.Cmd, source, label
|
||||
}
|
||||
}
|
||||
|
||||
type stringTaskResult struct {
|
||||
output string
|
||||
err error
|
||||
}
|
||||
|
||||
// runStringTaskWithCommandTickTimeout executes a string-returning task with a
|
||||
// command-tick-based timeout loop so timeout behavior stays consistent with the
|
||||
// command watchdog pacing policy.
|
||||
func runStringTaskWithCommandTickTimeout(
|
||||
ctx context.Context,
|
||||
timeoutSec int,
|
||||
baseTick time.Duration,
|
||||
run func(context.Context) (string, error),
|
||||
) (string, error) {
|
||||
if run == nil {
|
||||
return "", fmt.Errorf("run function is nil")
|
||||
}
|
||||
if timeoutSec <= 0 {
|
||||
return run(ctx)
|
||||
}
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
|
||||
timeout := time.Duration(timeoutSec) * time.Second
|
||||
started := time.Now()
|
||||
tick := normalizeCommandTick(baseTick)
|
||||
if tick <= 0 {
|
||||
tick = 2 * time.Second
|
||||
}
|
||||
|
||||
runCtx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
done := make(chan stringTaskResult, 1)
|
||||
go func() {
|
||||
out, err := run(runCtx)
|
||||
done <- stringTaskResult{output: out, err: err}
|
||||
}()
|
||||
|
||||
timer := time.NewTimer(tick)
|
||||
defer timer.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
cancel()
|
||||
return "", ctx.Err()
|
||||
case res := <-done:
|
||||
return res.output, res.err
|
||||
case <-timer.C:
|
||||
elapsed := time.Since(started)
|
||||
if elapsed >= timeout {
|
||||
cancel()
|
||||
select {
|
||||
case res := <-done:
|
||||
if res.err != nil {
|
||||
return "", fmt.Errorf("%w: %v", ErrCommandTickTimeout, res.err)
|
||||
}
|
||||
case <-time.After(2 * time.Second):
|
||||
}
|
||||
return "", fmt.Errorf("%w: %ds", ErrCommandTickTimeout, timeoutSec)
|
||||
}
|
||||
next := nextCommandTick(tick, elapsed)
|
||||
timer.Reset(next)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (wd *commandWatchdog) buildQueueSnapshotLocked() map[string]interface{} {
|
||||
if wd == nil {
|
||||
return nil
|
||||
|
||||
@@ -49,6 +49,26 @@ func (t *SpawnTool) Parameters() map[string]interface{} {
|
||||
"type": "string",
|
||||
"description": "Optional logical agent ID. If omitted, role will be used as fallback.",
|
||||
},
|
||||
"max_retries": map[string]interface{}{
|
||||
"type": "integer",
|
||||
"description": "Optional retry limit for this task.",
|
||||
},
|
||||
"retry_backoff_ms": map[string]interface{}{
|
||||
"type": "integer",
|
||||
"description": "Optional retry backoff in milliseconds.",
|
||||
},
|
||||
"timeout_sec": map[string]interface{}{
|
||||
"type": "integer",
|
||||
"description": "Optional per-attempt timeout in seconds.",
|
||||
},
|
||||
"max_task_chars": map[string]interface{}{
|
||||
"type": "integer",
|
||||
"description": "Optional task size quota in characters.",
|
||||
},
|
||||
"max_result_chars": map[string]interface{}{
|
||||
"type": "integer",
|
||||
"description": "Optional result size quota in characters.",
|
||||
},
|
||||
"pipeline_id": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "Optional pipeline ID for orchestrated multi-agent workflow",
|
||||
@@ -86,6 +106,11 @@ func (t *SpawnTool) Execute(ctx context.Context, args map[string]interface{}) (s
|
||||
label, _ := args["label"].(string)
|
||||
role, _ := args["role"].(string)
|
||||
agentID, _ := args["agent_id"].(string)
|
||||
maxRetries := intArg(args, "max_retries")
|
||||
retryBackoff := intArg(args, "retry_backoff_ms")
|
||||
timeoutSec := intArg(args, "timeout_sec")
|
||||
maxTaskChars := intArg(args, "max_task_chars")
|
||||
maxResultChars := intArg(args, "max_result_chars")
|
||||
pipelineID, _ := args["pipeline_id"].(string)
|
||||
taskID, _ := args["task_id"].(string)
|
||||
if label == "" && role != "" {
|
||||
@@ -114,14 +139,19 @@ func (t *SpawnTool) Execute(ctx context.Context, args map[string]interface{}) (s
|
||||
}
|
||||
|
||||
result, err := t.manager.Spawn(ctx, SubagentSpawnOptions{
|
||||
Task: task,
|
||||
Label: label,
|
||||
Role: role,
|
||||
AgentID: agentID,
|
||||
OriginChannel: originChannel,
|
||||
OriginChatID: originChatID,
|
||||
PipelineID: pipelineID,
|
||||
PipelineTask: taskID,
|
||||
Task: task,
|
||||
Label: label,
|
||||
Role: role,
|
||||
AgentID: agentID,
|
||||
MaxRetries: maxRetries,
|
||||
RetryBackoff: retryBackoff,
|
||||
TimeoutSec: timeoutSec,
|
||||
MaxTaskChars: maxTaskChars,
|
||||
MaxResultChars: maxResultChars,
|
||||
OriginChannel: originChannel,
|
||||
OriginChatID: originChatID,
|
||||
PipelineID: pipelineID,
|
||||
PipelineTask: taskID,
|
||||
})
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to spawn subagent: %w", err)
|
||||
@@ -129,3 +159,19 @@ func (t *SpawnTool) Execute(ctx context.Context, args map[string]interface{}) (s
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func intArg(args map[string]interface{}, key string) int {
|
||||
if args == nil {
|
||||
return 0
|
||||
}
|
||||
if v, ok := args[key].(float64); ok {
|
||||
return int(v)
|
||||
}
|
||||
if v, ok := args[key].(int); ok {
|
||||
return v
|
||||
}
|
||||
if v, ok := args[key].(int64); ok {
|
||||
return int(v)
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
@@ -14,25 +14,31 @@ import (
|
||||
)
|
||||
|
||||
type SubagentTask struct {
|
||||
ID string
|
||||
Task string
|
||||
Label string
|
||||
Role string
|
||||
AgentID string
|
||||
SessionKey string
|
||||
MemoryNS string
|
||||
SystemPrompt string
|
||||
ToolAllowlist []string
|
||||
PipelineID string
|
||||
PipelineTask string
|
||||
SharedState map[string]interface{}
|
||||
OriginChannel string
|
||||
OriginChatID string
|
||||
Status string
|
||||
Result string
|
||||
Steering []string
|
||||
Created int64
|
||||
Updated int64
|
||||
ID string `json:"id"`
|
||||
Task string `json:"task"`
|
||||
Label string `json:"label"`
|
||||
Role string `json:"role"`
|
||||
AgentID string `json:"agent_id"`
|
||||
SessionKey string `json:"session_key"`
|
||||
MemoryNS string `json:"memory_ns"`
|
||||
SystemPrompt string `json:"system_prompt,omitempty"`
|
||||
ToolAllowlist []string `json:"tool_allowlist,omitempty"`
|
||||
MaxRetries int `json:"max_retries,omitempty"`
|
||||
RetryBackoff int `json:"retry_backoff,omitempty"`
|
||||
TimeoutSec int `json:"timeout_sec,omitempty"`
|
||||
MaxTaskChars int `json:"max_task_chars,omitempty"`
|
||||
MaxResultChars int `json:"max_result_chars,omitempty"`
|
||||
RetryCount int `json:"retry_count,omitempty"`
|
||||
PipelineID string `json:"pipeline_id,omitempty"`
|
||||
PipelineTask string `json:"pipeline_task,omitempty"`
|
||||
SharedState map[string]interface{} `json:"shared_state,omitempty"`
|
||||
OriginChannel string `json:"origin_channel,omitempty"`
|
||||
OriginChatID string `json:"origin_chat_id,omitempty"`
|
||||
Status string `json:"status"`
|
||||
Result string `json:"result,omitempty"`
|
||||
Steering []string `json:"steering,omitempty"`
|
||||
Created int64 `json:"created"`
|
||||
Updated int64 `json:"updated"`
|
||||
}
|
||||
|
||||
type SubagentManager struct {
|
||||
@@ -50,14 +56,19 @@ type SubagentManager struct {
|
||||
}
|
||||
|
||||
type SubagentSpawnOptions struct {
|
||||
Task string
|
||||
Label string
|
||||
Role string
|
||||
AgentID string
|
||||
OriginChannel string
|
||||
OriginChatID string
|
||||
PipelineID string
|
||||
PipelineTask string
|
||||
Task string
|
||||
Label string
|
||||
Role string
|
||||
AgentID string
|
||||
MaxRetries int
|
||||
RetryBackoff int
|
||||
TimeoutSec int
|
||||
MaxTaskChars int
|
||||
MaxResultChars int
|
||||
OriginChannel string
|
||||
OriginChatID string
|
||||
PipelineID string
|
||||
PipelineTask string
|
||||
}
|
||||
|
||||
func NewSubagentManager(provider providers.LLMProvider, workspace string, bus *bus.MessageBus, orc *Orchestrator) *SubagentManager {
|
||||
@@ -110,6 +121,11 @@ func (sm *SubagentManager) Spawn(ctx context.Context, opts SubagentSpawnOptions)
|
||||
memoryNS := agentID
|
||||
systemPrompt := ""
|
||||
toolAllowlist := []string(nil)
|
||||
maxRetries := 0
|
||||
retryBackoff := 1000
|
||||
timeoutSec := 0
|
||||
maxTaskChars := 0
|
||||
maxResultChars := 0
|
||||
if profile == nil && sm.profileStore != nil {
|
||||
if p, ok, err := sm.profileStore.Get(agentID); err != nil {
|
||||
return "", err
|
||||
@@ -132,7 +148,35 @@ func (sm *SubagentManager) Spawn(ctx context.Context, opts SubagentSpawnOptions)
|
||||
}
|
||||
systemPrompt = strings.TrimSpace(profile.SystemPrompt)
|
||||
toolAllowlist = append([]string(nil), profile.ToolAllowlist...)
|
||||
maxRetries = profile.MaxRetries
|
||||
retryBackoff = profile.RetryBackoff
|
||||
timeoutSec = profile.TimeoutSec
|
||||
maxTaskChars = profile.MaxTaskChars
|
||||
maxResultChars = profile.MaxResultChars
|
||||
}
|
||||
if opts.MaxRetries > 0 {
|
||||
maxRetries = opts.MaxRetries
|
||||
}
|
||||
if opts.RetryBackoff > 0 {
|
||||
retryBackoff = opts.RetryBackoff
|
||||
}
|
||||
if opts.TimeoutSec > 0 {
|
||||
timeoutSec = opts.TimeoutSec
|
||||
}
|
||||
if opts.MaxTaskChars > 0 {
|
||||
maxTaskChars = opts.MaxTaskChars
|
||||
}
|
||||
if opts.MaxResultChars > 0 {
|
||||
maxResultChars = opts.MaxResultChars
|
||||
}
|
||||
if maxTaskChars > 0 && len(task) > maxTaskChars {
|
||||
return "", fmt.Errorf("task exceeds max_task_chars quota (%d > %d)", len(task), maxTaskChars)
|
||||
}
|
||||
maxRetries = normalizePositiveBound(maxRetries, 0, 8)
|
||||
retryBackoff = normalizePositiveBound(retryBackoff, 500, 120000)
|
||||
timeoutSec = normalizePositiveBound(timeoutSec, 0, 3600)
|
||||
maxTaskChars = normalizePositiveBound(maxTaskChars, 0, 400000)
|
||||
maxResultChars = normalizePositiveBound(maxResultChars, 0, 400000)
|
||||
if role == "" {
|
||||
role = originalRole
|
||||
}
|
||||
@@ -150,22 +194,28 @@ func (sm *SubagentManager) Spawn(ctx context.Context, opts SubagentSpawnOptions)
|
||||
|
||||
now := time.Now().UnixMilli()
|
||||
subagentTask := &SubagentTask{
|
||||
ID: taskID,
|
||||
Task: task,
|
||||
Label: label,
|
||||
Role: role,
|
||||
AgentID: agentID,
|
||||
SessionKey: sessionKey,
|
||||
MemoryNS: memoryNS,
|
||||
SystemPrompt: systemPrompt,
|
||||
ToolAllowlist: toolAllowlist,
|
||||
PipelineID: pipelineID,
|
||||
PipelineTask: pipelineTask,
|
||||
OriginChannel: originChannel,
|
||||
OriginChatID: originChatID,
|
||||
Status: "running",
|
||||
Created: now,
|
||||
Updated: now,
|
||||
ID: taskID,
|
||||
Task: task,
|
||||
Label: label,
|
||||
Role: role,
|
||||
AgentID: agentID,
|
||||
SessionKey: sessionKey,
|
||||
MemoryNS: memoryNS,
|
||||
SystemPrompt: systemPrompt,
|
||||
ToolAllowlist: toolAllowlist,
|
||||
MaxRetries: maxRetries,
|
||||
RetryBackoff: retryBackoff,
|
||||
TimeoutSec: timeoutSec,
|
||||
MaxTaskChars: maxTaskChars,
|
||||
MaxResultChars: maxResultChars,
|
||||
RetryCount: 0,
|
||||
PipelineID: pipelineID,
|
||||
PipelineTask: pipelineTask,
|
||||
OriginChannel: originChannel,
|
||||
OriginChatID: originChatID,
|
||||
Status: "running",
|
||||
Created: now,
|
||||
Updated: now,
|
||||
}
|
||||
taskCtx, cancel := context.WithCancel(ctx)
|
||||
sm.tasks[taskID] = subagentTask
|
||||
@@ -202,90 +252,25 @@ func (sm *SubagentManager) runTask(ctx context.Context, task *SubagentTask) {
|
||||
if sm.orc != nil && task.PipelineID != "" && task.PipelineTask != "" {
|
||||
_ = sm.orc.MarkTaskRunning(task.PipelineID, task.PipelineTask)
|
||||
}
|
||||
|
||||
// 1. Independent agent logic: supports recursive tool calling.
|
||||
// This lightweight approach reuses AgentLoop logic for full subagent capability.
|
||||
// subagent.go cannot depend on agent package inversely, so use function injection.
|
||||
|
||||
// Fall back to one-shot chat when RunFunc is not injected.
|
||||
if sm.runFunc != nil {
|
||||
result, err := sm.runFunc(ctx, task)
|
||||
sm.mu.Lock()
|
||||
if err != nil {
|
||||
task.Status = "failed"
|
||||
task.Result = fmt.Sprintf("Error: %v", err)
|
||||
task.Updated = time.Now().UnixMilli()
|
||||
if sm.orc != nil && task.PipelineID != "" && task.PipelineTask != "" {
|
||||
_ = sm.orc.MarkTaskDone(task.PipelineID, task.PipelineTask, task.Result, err)
|
||||
}
|
||||
} else {
|
||||
task.Status = "completed"
|
||||
task.Result = result
|
||||
task.Updated = time.Now().UnixMilli()
|
||||
if sm.orc != nil && task.PipelineID != "" && task.PipelineTask != "" {
|
||||
_ = sm.orc.MarkTaskDone(task.PipelineID, task.PipelineTask, task.Result, nil)
|
||||
}
|
||||
result, runErr := sm.runWithRetry(ctx, task)
|
||||
sm.mu.Lock()
|
||||
if runErr != nil {
|
||||
task.Status = "failed"
|
||||
task.Result = fmt.Sprintf("Error: %v", runErr)
|
||||
task.Result = applySubagentResultQuota(task.Result, task.MaxResultChars)
|
||||
task.Updated = time.Now().UnixMilli()
|
||||
if sm.orc != nil && task.PipelineID != "" && task.PipelineTask != "" {
|
||||
_ = sm.orc.MarkTaskDone(task.PipelineID, task.PipelineTask, task.Result, runErr)
|
||||
}
|
||||
sm.mu.Unlock()
|
||||
} else {
|
||||
// Original one-shot logic
|
||||
if sm.provider == nil {
|
||||
sm.mu.Lock()
|
||||
task.Status = "failed"
|
||||
task.Result = "Error: no llm provider configured for subagent execution"
|
||||
task.Updated = time.Now().UnixMilli()
|
||||
if sm.orc != nil && task.PipelineID != "" && task.PipelineTask != "" {
|
||||
_ = sm.orc.MarkTaskDone(task.PipelineID, task.PipelineTask, task.Result, fmt.Errorf("no llm provider configured for subagent execution"))
|
||||
}
|
||||
sm.mu.Unlock()
|
||||
return
|
||||
task.Status = "completed"
|
||||
task.Result = applySubagentResultQuota(result, task.MaxResultChars)
|
||||
task.Updated = time.Now().UnixMilli()
|
||||
if sm.orc != nil && task.PipelineID != "" && task.PipelineTask != "" {
|
||||
_ = sm.orc.MarkTaskDone(task.PipelineID, task.PipelineTask, task.Result, nil)
|
||||
}
|
||||
systemPrompt := "You are a subagent. Follow workspace AGENTS.md and complete the task independently."
|
||||
rolePrompt := strings.TrimSpace(task.SystemPrompt)
|
||||
if ws := strings.TrimSpace(sm.workspace); ws != "" {
|
||||
if data, err := os.ReadFile(filepath.Join(ws, "AGENTS.md")); err == nil {
|
||||
txt := strings.TrimSpace(string(data))
|
||||
if txt != "" {
|
||||
systemPrompt = "Workspace policy (AGENTS.md):\n" + txt + "\n\nComplete the given task independently and report the result."
|
||||
}
|
||||
}
|
||||
}
|
||||
if rolePrompt != "" {
|
||||
systemPrompt += "\n\nRole-specific profile prompt:\n" + rolePrompt
|
||||
}
|
||||
messages := []providers.Message{
|
||||
{
|
||||
Role: "system",
|
||||
Content: systemPrompt,
|
||||
},
|
||||
{
|
||||
Role: "user",
|
||||
Content: task.Task,
|
||||
},
|
||||
}
|
||||
|
||||
response, err := sm.provider.Chat(ctx, messages, nil, sm.provider.GetDefaultModel(), map[string]interface{}{
|
||||
"max_tokens": 4096,
|
||||
})
|
||||
|
||||
sm.mu.Lock()
|
||||
if err != nil {
|
||||
task.Status = "failed"
|
||||
task.Result = fmt.Sprintf("Error: %v", err)
|
||||
task.Updated = time.Now().UnixMilli()
|
||||
if sm.orc != nil && task.PipelineID != "" && task.PipelineTask != "" {
|
||||
_ = sm.orc.MarkTaskDone(task.PipelineID, task.PipelineTask, task.Result, err)
|
||||
}
|
||||
} else {
|
||||
task.Status = "completed"
|
||||
task.Result = response.Content
|
||||
task.Updated = time.Now().UnixMilli()
|
||||
if sm.orc != nil && task.PipelineID != "" && task.PipelineTask != "" {
|
||||
_ = sm.orc.MarkTaskDone(task.PipelineID, task.PipelineTask, task.Result, nil)
|
||||
}
|
||||
}
|
||||
sm.mu.Unlock()
|
||||
}
|
||||
sm.mu.Unlock()
|
||||
|
||||
// 2. Result broadcast (keep existing behavior)
|
||||
if sm.bus != nil {
|
||||
@@ -310,6 +295,8 @@ func (sm *SubagentManager) runTask(ctx context.Context, task *SubagentTask) {
|
||||
"role": task.Role,
|
||||
"session_key": task.SessionKey,
|
||||
"memory_ns": task.MemoryNS,
|
||||
"retry_count": fmt.Sprintf("%d", task.RetryCount),
|
||||
"timeout_sec": fmt.Sprintf("%d", task.TimeoutSec),
|
||||
"pipeline_id": task.PipelineID,
|
||||
"pipeline_task": task.PipelineTask,
|
||||
},
|
||||
@@ -317,6 +304,92 @@ func (sm *SubagentManager) runTask(ctx context.Context, task *SubagentTask) {
|
||||
}
|
||||
}
|
||||
|
||||
func (sm *SubagentManager) runWithRetry(ctx context.Context, task *SubagentTask) (string, error) {
|
||||
maxRetries := normalizePositiveBound(task.MaxRetries, 0, 8)
|
||||
backoffMs := normalizePositiveBound(task.RetryBackoff, 500, 120000)
|
||||
timeoutSec := normalizePositiveBound(task.TimeoutSec, 0, 3600)
|
||||
|
||||
var lastErr error
|
||||
for attempt := 0; attempt <= maxRetries; attempt++ {
|
||||
result, err := runStringTaskWithCommandTickTimeout(
|
||||
ctx,
|
||||
timeoutSec,
|
||||
2*time.Second,
|
||||
func(runCtx context.Context) (string, error) {
|
||||
return sm.executeTaskOnce(runCtx, task)
|
||||
},
|
||||
)
|
||||
if err == nil {
|
||||
sm.mu.Lock()
|
||||
task.RetryCount = attempt
|
||||
task.Updated = time.Now().UnixMilli()
|
||||
sm.mu.Unlock()
|
||||
return result, nil
|
||||
}
|
||||
lastErr = err
|
||||
sm.mu.Lock()
|
||||
task.RetryCount = attempt
|
||||
task.Updated = time.Now().UnixMilli()
|
||||
sm.mu.Unlock()
|
||||
if attempt >= maxRetries {
|
||||
break
|
||||
}
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return "", ctx.Err()
|
||||
case <-time.After(time.Duration(backoffMs) * time.Millisecond):
|
||||
}
|
||||
}
|
||||
if lastErr == nil {
|
||||
lastErr = fmt.Errorf("subagent task failed with unknown error")
|
||||
}
|
||||
return "", lastErr
|
||||
}
|
||||
|
||||
func (sm *SubagentManager) executeTaskOnce(ctx context.Context, task *SubagentTask) (string, error) {
|
||||
if task == nil {
|
||||
return "", fmt.Errorf("subagent task is nil")
|
||||
}
|
||||
if sm.runFunc != nil {
|
||||
return sm.runFunc(ctx, task)
|
||||
}
|
||||
if sm.provider == nil {
|
||||
return "", fmt.Errorf("no llm provider configured for subagent execution")
|
||||
}
|
||||
|
||||
systemPrompt := "You are a subagent. Follow workspace AGENTS.md and complete the task independently."
|
||||
rolePrompt := strings.TrimSpace(task.SystemPrompt)
|
||||
if ws := strings.TrimSpace(sm.workspace); ws != "" {
|
||||
if data, err := os.ReadFile(filepath.Join(ws, "AGENTS.md")); err == nil {
|
||||
txt := strings.TrimSpace(string(data))
|
||||
if txt != "" {
|
||||
systemPrompt = "Workspace policy (AGENTS.md):\n" + txt + "\n\nComplete the given task independently and report the result."
|
||||
}
|
||||
}
|
||||
}
|
||||
if rolePrompt != "" {
|
||||
systemPrompt += "\n\nRole-specific profile prompt:\n" + rolePrompt
|
||||
}
|
||||
messages := []providers.Message{
|
||||
{
|
||||
Role: "system",
|
||||
Content: systemPrompt,
|
||||
},
|
||||
{
|
||||
Role: "user",
|
||||
Content: task.Task,
|
||||
},
|
||||
}
|
||||
|
||||
response, err := sm.provider.Chat(ctx, messages, nil, sm.provider.GetDefaultModel(), map[string]interface{}{
|
||||
"max_tokens": 4096,
|
||||
})
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return response.Content, nil
|
||||
}
|
||||
|
||||
type SubagentRunFunc func(ctx context.Context, task *SubagentTask) (string, error)
|
||||
|
||||
func (sm *SubagentManager) SetRunFunc(f SubagentRunFunc) {
|
||||
@@ -402,14 +475,19 @@ func (sm *SubagentManager) ResumeTask(ctx context.Context, taskID string) (strin
|
||||
label = label + "-resumed"
|
||||
}
|
||||
_, err := sm.Spawn(ctx, SubagentSpawnOptions{
|
||||
Task: t.Task,
|
||||
Label: label,
|
||||
Role: t.Role,
|
||||
AgentID: t.AgentID,
|
||||
OriginChannel: t.OriginChannel,
|
||||
OriginChatID: t.OriginChatID,
|
||||
PipelineID: t.PipelineID,
|
||||
PipelineTask: t.PipelineTask,
|
||||
Task: t.Task,
|
||||
Label: label,
|
||||
Role: t.Role,
|
||||
AgentID: t.AgentID,
|
||||
MaxRetries: t.MaxRetries,
|
||||
RetryBackoff: t.RetryBackoff,
|
||||
TimeoutSec: t.TimeoutSec,
|
||||
MaxTaskChars: t.MaxTaskChars,
|
||||
MaxResultChars: t.MaxResultChars,
|
||||
OriginChannel: t.OriginChannel,
|
||||
OriginChatID: t.OriginChatID,
|
||||
PipelineID: t.PipelineID,
|
||||
PipelineTask: t.PipelineTask,
|
||||
})
|
||||
if err != nil {
|
||||
return "", false
|
||||
@@ -433,6 +511,31 @@ func (sm *SubagentManager) pruneArchivedLocked() {
|
||||
}
|
||||
}
|
||||
|
||||
func normalizePositiveBound(v, min, max int) int {
|
||||
if v < min {
|
||||
return min
|
||||
}
|
||||
if max > 0 && v > max {
|
||||
return max
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
func applySubagentResultQuota(result string, maxChars int) string {
|
||||
if maxChars <= 0 {
|
||||
return result
|
||||
}
|
||||
if len(result) <= maxChars {
|
||||
return result
|
||||
}
|
||||
suffix := "\n\n[TRUNCATED: result exceeds max_result_chars quota]"
|
||||
trimmed := result[:maxChars]
|
||||
if len(trimmed)+len(suffix) > maxChars && maxChars > len(suffix) {
|
||||
trimmed = trimmed[:maxChars-len(suffix)]
|
||||
}
|
||||
return strings.TrimSpace(trimmed) + suffix
|
||||
}
|
||||
|
||||
func normalizeSubagentIdentifier(in string) string {
|
||||
in = strings.TrimSpace(strings.ToLower(in))
|
||||
if in == "" {
|
||||
|
||||
@@ -19,6 +19,11 @@ type SubagentProfile struct {
|
||||
SystemPrompt string `json:"system_prompt,omitempty"`
|
||||
ToolAllowlist []string `json:"tool_allowlist,omitempty"`
|
||||
MemoryNamespace string `json:"memory_namespace,omitempty"`
|
||||
MaxRetries int `json:"max_retries,omitempty"`
|
||||
RetryBackoff int `json:"retry_backoff_ms,omitempty"`
|
||||
TimeoutSec int `json:"timeout_sec,omitempty"`
|
||||
MaxTaskChars int `json:"max_task_chars,omitempty"`
|
||||
MaxResultChars int `json:"max_result_chars,omitempty"`
|
||||
Status string `json:"status"`
|
||||
CreatedAt int64 `json:"created_at"`
|
||||
UpdatedAt int64 `json:"updated_at"`
|
||||
@@ -188,6 +193,11 @@ func normalizeSubagentProfile(in SubagentProfile) SubagentProfile {
|
||||
}
|
||||
p.Status = normalizeProfileStatus(p.Status)
|
||||
p.ToolAllowlist = normalizeToolAllowlist(p.ToolAllowlist)
|
||||
p.MaxRetries = clampInt(p.MaxRetries, 0, 8)
|
||||
p.RetryBackoff = clampInt(p.RetryBackoff, 500, 120000)
|
||||
p.TimeoutSec = clampInt(p.TimeoutSec, 0, 3600)
|
||||
p.MaxTaskChars = clampInt(p.MaxTaskChars, 0, 400000)
|
||||
p.MaxResultChars = clampInt(p.MaxResultChars, 0, 400000)
|
||||
return p
|
||||
}
|
||||
|
||||
@@ -223,18 +233,25 @@ func normalizeStringList(in []string) []string {
|
||||
}
|
||||
|
||||
func normalizeToolAllowlist(in []string) []string {
|
||||
items := normalizeStringList(in)
|
||||
items := ExpandToolAllowlistEntries(normalizeStringList(in))
|
||||
if len(items) == 0 {
|
||||
return nil
|
||||
}
|
||||
for i := range items {
|
||||
items[i] = strings.ToLower(strings.TrimSpace(items[i]))
|
||||
}
|
||||
items = normalizeStringList(items)
|
||||
sort.Strings(items)
|
||||
return items
|
||||
}
|
||||
|
||||
func clampInt(v, min, max int) int {
|
||||
if v < min {
|
||||
return min
|
||||
}
|
||||
if max > 0 && v > max {
|
||||
return max
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
func parseStringList(raw interface{}) []string {
|
||||
items, ok := raw.([]interface{})
|
||||
if !ok {
|
||||
@@ -281,9 +298,15 @@ func (t *SubagentProfileTool) Parameters() map[string]interface{} {
|
||||
"memory_namespace": map[string]interface{}{"type": "string"},
|
||||
"status": map[string]interface{}{"type": "string", "description": "active|disabled"},
|
||||
"tool_allowlist": map[string]interface{}{
|
||||
"type": "array",
|
||||
"items": map[string]interface{}{"type": "string"},
|
||||
"type": "array",
|
||||
"description": "Tool allowlist entries. Supports tool names, '*'/'all', and grouped tokens like 'group:files_read' or '@pipeline'.",
|
||||
"items": map[string]interface{}{"type": "string"},
|
||||
},
|
||||
"max_retries": map[string]interface{}{"type": "integer", "description": "Retry limit for subagent task execution."},
|
||||
"retry_backoff_ms": map[string]interface{}{"type": "integer", "description": "Backoff between retries in milliseconds."},
|
||||
"timeout_sec": map[string]interface{}{"type": "integer", "description": "Per-attempt timeout in seconds."},
|
||||
"max_task_chars": map[string]interface{}{"type": "integer", "description": "Task input size quota (characters)."},
|
||||
"max_result_chars": map[string]interface{}{"type": "integer", "description": "Result output size quota (characters)."},
|
||||
},
|
||||
"required": []string{"action"},
|
||||
}
|
||||
@@ -344,6 +367,11 @@ func (t *SubagentProfileTool) Execute(ctx context.Context, args map[string]inter
|
||||
MemoryNamespace: stringArg(args, "memory_namespace"),
|
||||
Status: stringArg(args, "status"),
|
||||
ToolAllowlist: parseStringList(args["tool_allowlist"]),
|
||||
MaxRetries: profileIntArg(args, "max_retries"),
|
||||
RetryBackoff: profileIntArg(args, "retry_backoff_ms"),
|
||||
TimeoutSec: profileIntArg(args, "timeout_sec"),
|
||||
MaxTaskChars: profileIntArg(args, "max_task_chars"),
|
||||
MaxResultChars: profileIntArg(args, "max_result_chars"),
|
||||
}
|
||||
saved, err := t.store.Upsert(p)
|
||||
if err != nil {
|
||||
@@ -380,6 +408,21 @@ func (t *SubagentProfileTool) Execute(ctx context.Context, args map[string]inter
|
||||
if _, ok := args["tool_allowlist"]; ok {
|
||||
next.ToolAllowlist = parseStringList(args["tool_allowlist"])
|
||||
}
|
||||
if _, ok := args["max_retries"]; ok {
|
||||
next.MaxRetries = profileIntArg(args, "max_retries")
|
||||
}
|
||||
if _, ok := args["retry_backoff_ms"]; ok {
|
||||
next.RetryBackoff = profileIntArg(args, "retry_backoff_ms")
|
||||
}
|
||||
if _, ok := args["timeout_sec"]; ok {
|
||||
next.TimeoutSec = profileIntArg(args, "timeout_sec")
|
||||
}
|
||||
if _, ok := args["max_task_chars"]; ok {
|
||||
next.MaxTaskChars = profileIntArg(args, "max_task_chars")
|
||||
}
|
||||
if _, ok := args["max_result_chars"]; ok {
|
||||
next.MaxResultChars = profileIntArg(args, "max_result_chars")
|
||||
}
|
||||
saved, err := t.store.Upsert(next)
|
||||
if err != nil {
|
||||
return "", err
|
||||
@@ -423,3 +466,19 @@ func stringArg(args map[string]interface{}, key string) string {
|
||||
v, _ := args[key].(string)
|
||||
return strings.TrimSpace(v)
|
||||
}
|
||||
|
||||
func profileIntArg(args map[string]interface{}, key string) int {
|
||||
if args == nil {
|
||||
return 0
|
||||
}
|
||||
switch v := args[key].(type) {
|
||||
case float64:
|
||||
return int(v)
|
||||
case int:
|
||||
return v
|
||||
case int64:
|
||||
return int(v)
|
||||
default:
|
||||
return 0
|
||||
}
|
||||
}
|
||||
|
||||
123
pkg/tools/subagent_runtime_control_test.go
Normal file
123
pkg/tools/subagent_runtime_control_test.go
Normal file
@@ -0,0 +1,123 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestSubagentSpawnEnforcesTaskQuota(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
workspace := t.TempDir()
|
||||
manager := NewSubagentManager(nil, workspace, nil, nil)
|
||||
manager.SetRunFunc(func(ctx context.Context, task *SubagentTask) (string, error) {
|
||||
return "ok", nil
|
||||
})
|
||||
store := manager.ProfileStore()
|
||||
if store == nil {
|
||||
t.Fatalf("expected profile store")
|
||||
}
|
||||
if _, err := store.Upsert(SubagentProfile{
|
||||
AgentID: "coder",
|
||||
MaxTaskChars: 8,
|
||||
}); err != nil {
|
||||
t.Fatalf("failed to create profile: %v", err)
|
||||
}
|
||||
|
||||
_, err := manager.Spawn(context.Background(), SubagentSpawnOptions{
|
||||
Task: "this task is too long",
|
||||
AgentID: "coder",
|
||||
OriginChannel: "cli",
|
||||
OriginChatID: "direct",
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatalf("expected max_task_chars quota to reject spawn")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSubagentRunWithRetryEventuallySucceeds(t *testing.T) {
|
||||
workspace := t.TempDir()
|
||||
manager := NewSubagentManager(nil, workspace, nil, nil)
|
||||
attempts := 0
|
||||
manager.SetRunFunc(func(ctx context.Context, task *SubagentTask) (string, error) {
|
||||
attempts++
|
||||
if attempts == 1 {
|
||||
return "", errors.New("temporary failure")
|
||||
}
|
||||
return "retry success", nil
|
||||
})
|
||||
|
||||
_, err := manager.Spawn(context.Background(), SubagentSpawnOptions{
|
||||
Task: "retry task",
|
||||
AgentID: "coder",
|
||||
OriginChannel: "cli",
|
||||
OriginChatID: "direct",
|
||||
MaxRetries: 1,
|
||||
RetryBackoff: 1,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("spawn failed: %v", err)
|
||||
}
|
||||
|
||||
task := waitSubagentDone(t, manager, 4*time.Second)
|
||||
if task.Status != "completed" {
|
||||
t.Fatalf("expected completed task, got %s (%s)", task.Status, task.Result)
|
||||
}
|
||||
if task.RetryCount != 1 {
|
||||
t.Fatalf("expected retry_count=1, got %d", task.RetryCount)
|
||||
}
|
||||
if attempts < 2 {
|
||||
t.Fatalf("expected at least 2 attempts, got %d", attempts)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSubagentRunWithTimeoutFails(t *testing.T) {
|
||||
workspace := t.TempDir()
|
||||
manager := NewSubagentManager(nil, workspace, nil, nil)
|
||||
manager.SetRunFunc(func(ctx context.Context, task *SubagentTask) (string, error) {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return "", ctx.Err()
|
||||
case <-time.After(2 * time.Second):
|
||||
return "unexpected", nil
|
||||
}
|
||||
})
|
||||
|
||||
_, err := manager.Spawn(context.Background(), SubagentSpawnOptions{
|
||||
Task: "timeout task",
|
||||
AgentID: "coder",
|
||||
OriginChannel: "cli",
|
||||
OriginChatID: "direct",
|
||||
TimeoutSec: 1,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("spawn failed: %v", err)
|
||||
}
|
||||
|
||||
task := waitSubagentDone(t, manager, 4*time.Second)
|
||||
if task.Status != "failed" {
|
||||
t.Fatalf("expected failed task on timeout, got %s", task.Status)
|
||||
}
|
||||
if task.RetryCount != 0 {
|
||||
t.Fatalf("expected retry_count=0, got %d", task.RetryCount)
|
||||
}
|
||||
}
|
||||
|
||||
func waitSubagentDone(t *testing.T, manager *SubagentManager, timeout time.Duration) *SubagentTask {
|
||||
t.Helper()
|
||||
deadline := time.Now().Add(timeout)
|
||||
for time.Now().Before(deadline) {
|
||||
tasks := manager.ListTasks()
|
||||
if len(tasks) > 0 {
|
||||
task := tasks[0]
|
||||
if task.Status != "running" {
|
||||
return task
|
||||
}
|
||||
}
|
||||
time.Sleep(30 * time.Millisecond)
|
||||
}
|
||||
t.Fatalf("timeout waiting for subagent completion")
|
||||
return nil
|
||||
}
|
||||
@@ -62,8 +62,8 @@ func (t *SubagentsTool) Execute(ctx context.Context, args map[string]interface{}
|
||||
sb.WriteString("Subagents:\n")
|
||||
sort.Slice(tasks, func(i, j int) bool { return tasks[i].Created > tasks[j].Created })
|
||||
for i, task := range tasks {
|
||||
sb.WriteString(fmt.Sprintf("- #%d %s [%s] label=%s agent=%s role=%s session=%s allowlist=%d\n",
|
||||
i+1, task.ID, task.Status, task.Label, task.AgentID, task.Role, task.SessionKey, len(task.ToolAllowlist)))
|
||||
sb.WriteString(fmt.Sprintf("- #%d %s [%s] label=%s agent=%s role=%s session=%s allowlist=%d retry=%d timeout=%ds\n",
|
||||
i+1, task.ID, task.Status, task.Label, task.AgentID, task.Role, task.SessionKey, len(task.ToolAllowlist), task.MaxRetries, task.TimeoutSec))
|
||||
}
|
||||
return strings.TrimSpace(sb.String()), nil
|
||||
case "info":
|
||||
@@ -76,8 +76,8 @@ func (t *SubagentsTool) Execute(ctx context.Context, args map[string]interface{}
|
||||
var sb strings.Builder
|
||||
sb.WriteString("Subagents Summary:\n")
|
||||
for i, task := range tasks {
|
||||
sb.WriteString(fmt.Sprintf("- #%d %s [%s] label=%s agent=%s role=%s steering=%d allowlist=%d\n",
|
||||
i+1, task.ID, task.Status, task.Label, task.AgentID, task.Role, len(task.Steering), len(task.ToolAllowlist)))
|
||||
sb.WriteString(fmt.Sprintf("- #%d %s [%s] label=%s agent=%s role=%s steering=%d allowlist=%d retry=%d timeout=%ds\n",
|
||||
i+1, task.ID, task.Status, task.Label, task.AgentID, task.Role, len(task.Steering), len(task.ToolAllowlist), task.MaxRetries, task.TimeoutSec))
|
||||
}
|
||||
return strings.TrimSpace(sb.String()), nil
|
||||
}
|
||||
@@ -89,9 +89,10 @@ func (t *SubagentsTool) Execute(ctx context.Context, args map[string]interface{}
|
||||
if !ok {
|
||||
return "subagent not found", nil
|
||||
}
|
||||
return fmt.Sprintf("ID: %s\nStatus: %s\nLabel: %s\nAgent ID: %s\nRole: %s\nSession Key: %s\nMemory Namespace: %s\nTool Allowlist: %v\nCreated: %d\nUpdated: %d\nSteering Count: %d\nTask: %s\nResult:\n%s",
|
||||
return fmt.Sprintf("ID: %s\nStatus: %s\nLabel: %s\nAgent ID: %s\nRole: %s\nSession Key: %s\nMemory Namespace: %s\nTool Allowlist: %v\nMax Retries: %d\nRetry Count: %d\nRetry Backoff(ms): %d\nTimeout(s): %d\nMax Task Chars: %d\nMax Result Chars: %d\nCreated: %d\nUpdated: %d\nSteering Count: %d\nTask: %s\nResult:\n%s",
|
||||
task.ID, task.Status, task.Label, task.AgentID, task.Role, task.SessionKey, task.MemoryNS,
|
||||
task.ToolAllowlist, task.Created, task.Updated, len(task.Steering), task.Task, task.Result), nil
|
||||
task.ToolAllowlist, task.MaxRetries, task.RetryCount, task.RetryBackoff, task.TimeoutSec, task.MaxTaskChars, task.MaxResultChars,
|
||||
task.Created, task.Updated, len(task.Steering), task.Task, task.Result), nil
|
||||
case "kill":
|
||||
if strings.EqualFold(strings.TrimSpace(id), "all") {
|
||||
tasks := t.filterRecent(t.manager.ListTasks(), recentMinutes)
|
||||
@@ -138,7 +139,8 @@ func (t *SubagentsTool) Execute(ctx context.Context, args map[string]interface{}
|
||||
var sb strings.Builder
|
||||
sb.WriteString(fmt.Sprintf("Subagent %s Log\n", task.ID))
|
||||
sb.WriteString(fmt.Sprintf("Status: %s\n", task.Status))
|
||||
sb.WriteString(fmt.Sprintf("Agent ID: %s\nRole: %s\nSession Key: %s\nTool Allowlist: %v\n", task.AgentID, task.Role, task.SessionKey, task.ToolAllowlist))
|
||||
sb.WriteString(fmt.Sprintf("Agent ID: %s\nRole: %s\nSession Key: %s\nTool Allowlist: %v\nMax Retries: %d\nRetry Count: %d\nRetry Backoff(ms): %d\nTimeout(s): %d\n",
|
||||
task.AgentID, task.Role, task.SessionKey, task.ToolAllowlist, task.MaxRetries, task.RetryCount, task.RetryBackoff, task.TimeoutSec))
|
||||
if len(task.Steering) > 0 {
|
||||
sb.WriteString("Steering Messages:\n")
|
||||
for _, m := range task.Steering {
|
||||
|
||||
181
pkg/tools/tool_allowlist_groups.go
Normal file
181
pkg/tools/tool_allowlist_groups.go
Normal file
@@ -0,0 +1,181 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"sort"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type ToolAllowlistGroup struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
Aliases []string `json:"aliases,omitempty"`
|
||||
Tools []string `json:"tools"`
|
||||
}
|
||||
|
||||
var defaultToolAllowlistGroups = []ToolAllowlistGroup{
|
||||
{
|
||||
Name: "files_read",
|
||||
Description: "Read-only workspace file tools",
|
||||
Aliases: []string{"file_read", "readonly_files"},
|
||||
Tools: []string{"read_file", "list_dir", "repo_map", "read"},
|
||||
},
|
||||
{
|
||||
Name: "files_write",
|
||||
Description: "Workspace file modification tools",
|
||||
Aliases: []string{"file_write"},
|
||||
Tools: []string{"write_file", "edit_file", "write", "edit"},
|
||||
},
|
||||
{
|
||||
Name: "memory_read",
|
||||
Description: "Read-only memory tools",
|
||||
Aliases: []string{"mem_read"},
|
||||
Tools: []string{"memory_search", "memory_get"},
|
||||
},
|
||||
{
|
||||
Name: "memory_write",
|
||||
Description: "Memory write tools",
|
||||
Aliases: []string{"mem_write"},
|
||||
Tools: []string{"memory_write"},
|
||||
},
|
||||
{
|
||||
Name: "memory_all",
|
||||
Description: "All memory tools",
|
||||
Aliases: []string{"memory"},
|
||||
Tools: []string{"memory_search", "memory_get", "memory_write"},
|
||||
},
|
||||
{
|
||||
Name: "pipeline",
|
||||
Description: "Pipeline orchestration tools",
|
||||
Aliases: []string{"pipelines"},
|
||||
Tools: []string{"pipeline_create", "pipeline_status", "pipeline_state_set", "pipeline_dispatch"},
|
||||
},
|
||||
{
|
||||
Name: "subagents",
|
||||
Description: "Subagent management tools",
|
||||
Aliases: []string{"subagent", "agent_runtime"},
|
||||
Tools: []string{"spawn", "subagents", "subagent_profile"},
|
||||
},
|
||||
}
|
||||
|
||||
func ToolAllowlistGroups() []ToolAllowlistGroup {
|
||||
out := make([]ToolAllowlistGroup, 0, len(defaultToolAllowlistGroups))
|
||||
for _, g := range defaultToolAllowlistGroups {
|
||||
item := ToolAllowlistGroup{
|
||||
Name: strings.ToLower(strings.TrimSpace(g.Name)),
|
||||
Description: strings.TrimSpace(g.Description),
|
||||
Aliases: normalizeAllowlistTokenList(g.Aliases),
|
||||
Tools: normalizeAllowlistTokenList(g.Tools),
|
||||
}
|
||||
if item.Name == "" {
|
||||
continue
|
||||
}
|
||||
out = append(out, item)
|
||||
}
|
||||
sort.Slice(out, func(i, j int) bool { return out[i].Name < out[j].Name })
|
||||
return out
|
||||
}
|
||||
|
||||
func ExpandToolAllowlistEntries(entries []string) []string {
|
||||
if len(entries) == 0 {
|
||||
return nil
|
||||
}
|
||||
groups := ToolAllowlistGroups()
|
||||
resolved := make(map[string][]string, len(groups))
|
||||
for _, g := range groups {
|
||||
if g.Name != "" {
|
||||
resolved[g.Name] = g.Tools
|
||||
}
|
||||
for _, alias := range g.Aliases {
|
||||
resolved[alias] = g.Tools
|
||||
}
|
||||
}
|
||||
|
||||
out := map[string]struct{}{}
|
||||
for _, raw := range entries {
|
||||
token := normalizeAllowlistToken(raw)
|
||||
if token == "" {
|
||||
continue
|
||||
}
|
||||
if token == "*" || token == "all" {
|
||||
out[token] = struct{}{}
|
||||
continue
|
||||
}
|
||||
|
||||
if groupName, isGroupToken := parseAllowlistGroupToken(token); isGroupToken {
|
||||
if members, ok := resolved[groupName]; ok {
|
||||
for _, name := range members {
|
||||
out[name] = struct{}{}
|
||||
}
|
||||
continue
|
||||
}
|
||||
// Keep unknown group token as-is to preserve user intent and avoid silent mutation.
|
||||
out[token] = struct{}{}
|
||||
continue
|
||||
}
|
||||
|
||||
if members, ok := resolved[token]; ok {
|
||||
for _, name := range members {
|
||||
out[name] = struct{}{}
|
||||
}
|
||||
continue
|
||||
}
|
||||
out[token] = struct{}{}
|
||||
}
|
||||
|
||||
if len(out) == 0 {
|
||||
return nil
|
||||
}
|
||||
result := make([]string, 0, len(out))
|
||||
for name := range out {
|
||||
result = append(result, name)
|
||||
}
|
||||
sort.Strings(result)
|
||||
return result
|
||||
}
|
||||
|
||||
func parseAllowlistGroupToken(token string) (string, bool) {
|
||||
token = normalizeAllowlistToken(token)
|
||||
if token == "" {
|
||||
return "", false
|
||||
}
|
||||
if strings.HasPrefix(token, "group:") {
|
||||
v := normalizeAllowlistToken(strings.TrimPrefix(token, "group:"))
|
||||
if v != "" {
|
||||
return v, true
|
||||
}
|
||||
return "", false
|
||||
}
|
||||
if strings.HasPrefix(token, "@") {
|
||||
v := normalizeAllowlistToken(strings.TrimPrefix(token, "@"))
|
||||
if v != "" {
|
||||
return v, true
|
||||
}
|
||||
return "", false
|
||||
}
|
||||
return "", false
|
||||
}
|
||||
|
||||
func normalizeAllowlistTokenList(in []string) []string {
|
||||
if len(in) == 0 {
|
||||
return nil
|
||||
}
|
||||
seen := map[string]struct{}{}
|
||||
out := make([]string, 0, len(in))
|
||||
for _, item := range in {
|
||||
v := normalizeAllowlistToken(item)
|
||||
if v == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[v]; ok {
|
||||
continue
|
||||
}
|
||||
seen[v] = struct{}{}
|
||||
out = append(out, v)
|
||||
}
|
||||
sort.Strings(out)
|
||||
return out
|
||||
}
|
||||
|
||||
func normalizeAllowlistToken(in string) string {
|
||||
return strings.ToLower(strings.TrimSpace(in))
|
||||
}
|
||||
31
pkg/tools/tool_allowlist_groups_test.go
Normal file
31
pkg/tools/tool_allowlist_groups_test.go
Normal file
@@ -0,0 +1,31 @@
|
||||
package tools
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestExpandToolAllowlistEntries_GroupPrefix(t *testing.T) {
|
||||
got := ExpandToolAllowlistEntries([]string{"group:files_read"})
|
||||
contains := map[string]bool{}
|
||||
for _, item := range got {
|
||||
contains[item] = true
|
||||
}
|
||||
if !contains["read_file"] || !contains["list_dir"] {
|
||||
t.Fatalf("files_read group expansion missing expected tools: %v", got)
|
||||
}
|
||||
if contains["write_file"] {
|
||||
t.Fatalf("files_read group should not include write_file: %v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExpandToolAllowlistEntries_BareGroupAndAlias(t *testing.T) {
|
||||
got := ExpandToolAllowlistEntries([]string{"memory_all", "@pipeline"})
|
||||
contains := map[string]bool{}
|
||||
for _, item := range got {
|
||||
contains[item] = true
|
||||
}
|
||||
if !contains["memory_search"] || !contains["memory_write"] {
|
||||
t.Fatalf("memory_all expansion missing memory tools: %v", got)
|
||||
}
|
||||
if !contains["pipeline_dispatch"] || !contains["pipeline_status"] {
|
||||
t.Fatalf("pipeline alias expansion missing pipeline tools: %v", got)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user