diff --git a/Makefile b/Makefile index 1e41c73..c329f20 100644 --- a/Makefile +++ b/Makefile @@ -1,4 +1,4 @@ -.PHONY: all build build-all install uninstall clean help test install-bootstrap-docs sync-embed-workspace cleanup-embed-workspace +.PHONY: all build build-all install uninstall clean help test install-bootstrap-docs sync-embed-workspace cleanup-embed-workspace test-only clean-test-artifacts # Build variables BINARY_NAME=clawgo @@ -142,9 +142,15 @@ clean: @rm -rf $(BUILD_DIR) @echo "Clean complete" -## fmt: Format Go code -fmt: - @$(GO) fmt ./... +## test-only: Run tests without leaving build artifacts (cleans embed workspace and test cache) +test-only: sync-embed-workspace + @echo "Running tests..." + @set -e; trap '$(MAKE) cleanup-embed-workspace clean-test-artifacts' EXIT; \ + $(GO) test ./... + +## clean-test-artifacts: Remove test caches/artifacts generated by go test +clean-test-artifacts: + @$(GO) clean -testcache ## deps: Update dependencies deps: diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go index bd325dd..1473b1c 100644 --- a/pkg/agent/loop.go +++ b/pkg/agent/loop.go @@ -479,7 +479,6 @@ func NewAgentLoop(cfg *config.Config, msgBus *bus.MessageBus, provider providers toolsRegistry.Register(tools.NewWebSearchTool(braveAPIKey, cfg.Tools.Web.Search.MaxResults)) webFetchTool := tools.NewWebFetchTool(50000) toolsRegistry.Register(webFetchTool) - toolsRegistry.Register(tools.NewParallelFetchTool(webFetchTool)) // Register message tool messageTool := tools.NewMessageTool() @@ -514,9 +513,6 @@ func NewAgentLoop(cfg *config.Config, msgBus *bus.MessageBus, provider providers toolsRegistry.Register(tools.NewRepoMapTool(workspace)) toolsRegistry.Register(tools.NewSkillExecTool(workspace)) - // Register parallel execution tool (leveraging Go's concurrency) - toolsRegistry.Register(tools.NewParallelTool(toolsRegistry)) - // Register browser tool (integrated Chromium support) toolsRegistry.Register(tools.NewBrowserTool()) @@ -555,6 +551,8 @@ func NewAgentLoop(cfg *config.Config, msgBus *bus.MessageBus, provider providers policy := loadControlPolicyFromConfig(defaultControlPolicy(), cfg.Agents.Defaults.RuntimeControl) policy = applyLegacyControlPolicyEnvOverrides(policy) parallelSafeTools, maxParallelCalls := loadToolParallelPolicyFromConfig(cfg.Agents.Defaults.RuntimeControl) + toolsRegistry.Register(tools.NewParallelTool(toolsRegistry, maxParallelCalls, parallelSafeTools)) + toolsRegistry.Register(tools.NewParallelFetchTool(webFetchTool, maxParallelCalls, parallelSafeTools)) runStateTTL, runStateMax := loadRunStatePolicyFromConfig(cfg.Agents.Defaults.RuntimeControl) // Keep compatibility with older env names. runStateTTL = envDuration("CLAWGO_RUN_STATE_TTL", runStateTTL) diff --git a/pkg/tools/parallel.go b/pkg/tools/parallel.go index bb2c36b..fcfa78d 100644 --- a/pkg/tools/parallel.go +++ b/pkg/tools/parallel.go @@ -3,17 +3,36 @@ package tools import ( "context" "fmt" + "sort" + "strings" "sync" ) -const maxParallelToolCalls = 8 +const ( + maxParallelToolCallsLimit = 8 + defaultParallelToolCalls = 2 +) -type ParallelTool struct { - registry *ToolRegistry +type parallelCall struct { + Index int + Tool string + Args map[string]interface{} + ResultID string } -func NewParallelTool(registry *ToolRegistry) *ParallelTool { - return &ParallelTool{registry: registry} +type ParallelTool struct { + registry *ToolRegistry + maxParallelCalls int + parallelSafe map[string]struct{} +} + +func NewParallelTool(registry *ToolRegistry, maxParallelCalls int, parallelSafe map[string]struct{}) *ParallelTool { + limit := normalizeParallelLimit(maxParallelCalls) + return &ParallelTool{ + registry: registry, + maxParallelCalls: limit, + parallelSafe: normalizeSafeToolNames(parallelSafe), + } } func (t *ParallelTool) Name() string { @@ -60,11 +79,7 @@ func (t *ParallelTool) Execute(ctx context.Context, args map[string]interface{}) return "", fmt.Errorf("calls must be an array") } - results := make(map[string]string) - var mu sync.Mutex - var wg sync.WaitGroup - sem := make(chan struct{}, maxParallelToolCalls) - + calls := make([]parallelCall, 0, len(callsRaw)) for i, c := range callsRaw { call, ok := c.(map[string]interface{}) if !ok { @@ -78,31 +93,236 @@ func (t *ParallelTool) Execute(ctx context.Context, args map[string]interface{}) id = fmt.Sprintf("call_%d_%s", i, toolName) } - wg.Add(1) - go func(id, name string, args map[string]interface{}) { - defer wg.Done() - sem <- struct{}{} - defer func() { <-sem }() - - res, err := t.registry.Execute(ctx, name, args) - - mu.Lock() - defer mu.Unlock() - if err != nil { - results[id] = fmt.Sprintf("Error: %v", err) - } else { - results[id] = res - } - }(id, toolName, toolArgs) + calls = append(calls, parallelCall{ + Index: i, + Tool: toolName, + Args: toolArgs, + ResultID: id, + }) } - wg.Wait() - - // Simple string representation for the agent - var output string - for id, res := range results { - output += fmt.Sprintf("--- Result for %s ---\n%s\n", id, res) + if len(calls) == 0 { + return "", fmt.Errorf("no valid calls provided") } - return output, nil + if !t.callsParallelSafe(calls) { + return t.executeSerial(ctx, calls), nil + } + + batches := buildParallelBatches(t.registry, calls) + var output strings.Builder + for _, batch := range batches { + if len(batch) == 0 { + continue + } + if len(batch) == 1 || t.maxParallelCalls <= 1 { + output.WriteString(t.executeSerial(ctx, batch)) + continue + } + output.WriteString(t.executeParallel(ctx, batch)) + } + return output.String(), nil +} + +func (t *ParallelTool) executeSerial(ctx context.Context, calls []parallelCall) string { + results := make([]parallelResult, 0, len(calls)) + for _, call := range calls { + res, err := t.registry.Execute(ctx, call.Tool, call.Args) + results = append(results, parallelResult{ + Index: call.Index, + ResultID: call.ResultID, + Output: formatToolResult(res, err), + }) + } + return formatParallelResults(results) +} + +func (t *ParallelTool) executeParallel(ctx context.Context, calls []parallelCall) string { + limit := t.maxParallelCalls + if limit <= 0 { + limit = defaultParallelToolCalls + } + if limit > len(calls) { + limit = len(calls) + } + + results := make([]parallelResult, len(calls)) + var wg sync.WaitGroup + sem := make(chan struct{}, limit) + for i, call := range calls { + wg.Add(1) + sem <- struct{}{} + go func(index int, call parallelCall) { + defer wg.Done() + defer func() { <-sem }() + res, err := t.registry.Execute(ctx, call.Tool, call.Args) + results[index] = parallelResult{ + Index: call.Index, + ResultID: call.ResultID, + Output: formatToolResult(res, err), + } + }(i, call) + } + wg.Wait() + return formatParallelResults(results) +} + +func (t *ParallelTool) callsParallelSafe(calls []parallelCall) bool { + for _, call := range calls { + name := strings.ToLower(strings.TrimSpace(call.Tool)) + if name == "" { + return false + } + if tool, ok := t.registry.Get(call.Tool); ok { + if ps, ok := tool.(ParallelSafeTool); ok { + if !ps.ParallelSafe() { + return false + } + continue + } + } + if _, ok := t.parallelSafe[name]; !ok { + return false + } + } + return true +} + +func buildParallelBatches(registry *ToolRegistry, calls []parallelCall) [][]parallelCall { + if len(calls) == 0 { + return nil + } + batches := make([][]parallelCall, 0, len(calls)) + current := make([]parallelCall, 0, len(calls)) + used := map[string]struct{}{} + + flush := func() { + if len(current) == 0 { + return + } + batch := append([]parallelCall(nil), current...) + batches = append(batches, batch) + current = current[:0] + used = map[string]struct{}{} + } + + for _, call := range calls { + keys := toolResourceKeys(registry, call.Tool, call.Args) + if len(current) > 0 && hasResourceKeyConflict(used, keys) { + flush() + } + current = append(current, call) + for _, k := range keys { + used[k] = struct{}{} + } + } + flush() + return batches +} + +func toolResourceKeys(registry *ToolRegistry, name string, args map[string]interface{}) []string { + raw := strings.TrimSpace(name) + lower := strings.ToLower(raw) + if raw == "" || registry == nil { + return nil + } + tool, ok := registry.Get(raw) + if !ok && lower != raw { + tool, ok = registry.Get(lower) + } + if !ok || tool == nil { + return nil + } + rs, ok := tool.(ResourceScopedTool) + if !ok { + return nil + } + return normalizeResourceKeys(rs.ResourceKeys(args)) +} + +func normalizeResourceKeys(keys []string) []string { + if len(keys) == 0 { + return nil + } + out := make([]string, 0, len(keys)) + seen := make(map[string]struct{}, len(keys)) + for _, k := range keys { + n := strings.ToLower(strings.TrimSpace(k)) + if n == "" { + continue + } + if _, ok := seen[n]; ok { + continue + } + seen[n] = struct{}{} + out = append(out, n) + } + return out +} + +func hasResourceKeyConflict(used map[string]struct{}, keys []string) bool { + if len(keys) == 0 || len(used) == 0 { + return false + } + for _, k := range keys { + if _, ok := used[k]; ok { + return true + } + } + return false +} + +type parallelResult struct { + Index int + ResultID string + Output string +} + +func formatParallelResults(results []parallelResult) string { + sort.SliceStable(results, func(i, j int) bool { + return results[i].Index < results[j].Index + }) + var output strings.Builder + for _, res := range results { + if res.ResultID == "" { + continue + } + output.WriteString(fmt.Sprintf("--- Result for %s ---\n%s\n", res.ResultID, res.Output)) + } + return output.String() +} + +func formatToolResult(result string, err error) string { + if err != nil { + return fmt.Sprintf("Error: %v", err) + } + return result +} + +func normalizeParallelLimit(limit int) int { + if limit <= 0 { + limit = defaultParallelToolCalls + } + if limit < 1 { + limit = 1 + } + if limit > maxParallelToolCallsLimit { + limit = maxParallelToolCallsLimit + } + return limit +} + +func normalizeSafeToolNames(names map[string]struct{}) map[string]struct{} { + if len(names) == 0 { + return map[string]struct{}{} + } + out := make(map[string]struct{}, len(names)) + for name := range names { + n := strings.ToLower(strings.TrimSpace(name)) + if n == "" { + continue + } + out[n] = struct{}{} + } + return out } diff --git a/pkg/tools/parallel_fetch.go b/pkg/tools/parallel_fetch.go index 76903aa..f96c0b2 100644 --- a/pkg/tools/parallel_fetch.go +++ b/pkg/tools/parallel_fetch.go @@ -3,17 +3,23 @@ package tools import ( "context" "fmt" + "strings" "sync" ) -const maxParallelFetchCalls = 8 - type ParallelFetchTool struct { - fetcher *WebFetchTool + fetcher *WebFetchTool + maxParallelCalls int + parallelSafe map[string]struct{} } -func NewParallelFetchTool(fetcher *WebFetchTool) *ParallelFetchTool { - return &ParallelFetchTool{fetcher: fetcher} +func NewParallelFetchTool(fetcher *WebFetchTool, maxParallelCalls int, parallelSafe map[string]struct{}) *ParallelFetchTool { + limit := normalizeParallelLimit(maxParallelCalls) + return &ParallelFetchTool{ + fetcher: fetcher, + maxParallelCalls: limit, + parallelSafe: normalizeSafeToolNames(parallelSafe), + } } func (t *ParallelFetchTool) Name() string { @@ -46,20 +52,30 @@ func (t *ParallelFetchTool) Execute(ctx context.Context, args map[string]interfa return "", fmt.Errorf("urls must be an array") } + maxParallel := t.maxParallelCalls + if maxParallel <= 1 { + return t.executeSerial(ctx, urlsRaw), nil + } + + if !t.isParallelSafe() { + return t.executeSerial(ctx, urlsRaw), nil + } + results := make([]string, len(urlsRaw)) var wg sync.WaitGroup - sem := make(chan struct{}, maxParallelFetchCalls) + sem := make(chan struct{}, minParallelLimit(maxParallel, len(urlsRaw))) for i, u := range urlsRaw { urlStr, ok := u.(string) if !ok { + results[i] = "Error: invalid url" continue } wg.Add(1) + sem <- struct{}{} go func(index int, url string) { defer wg.Done() - sem <- struct{}{} defer func() { <-sem }() res, err := t.fetcher.Execute(ctx, map[string]interface{}{"url": url}) @@ -73,10 +89,55 @@ func (t *ParallelFetchTool) Execute(ctx context.Context, args map[string]interfa wg.Wait() - var output string - for i, res := range results { - output += fmt.Sprintf("=== Result %d ===\n%s\n\n", i+1, res) - } - - return output, nil + return formatFetchResults(results), nil +} + +func (t *ParallelFetchTool) executeSerial(ctx context.Context, urlsRaw []interface{}) string { + results := make([]string, len(urlsRaw)) + for i, u := range urlsRaw { + urlStr, ok := u.(string) + if !ok { + results[i] = "Error: invalid url" + continue + } + res, err := t.fetcher.Execute(ctx, map[string]interface{}{"url": urlStr}) + if err != nil { + results[i] = fmt.Sprintf("Error fetching %s: %v", urlStr, err) + } else { + results[i] = res + } + } + return formatFetchResults(results) +} + +func (t *ParallelFetchTool) isParallelSafe() bool { + if t.parallelSafe == nil { + return false + } + if tool, ok := any(t.fetcher).(ParallelSafeTool); ok { + return tool.ParallelSafe() + } + _, ok := t.parallelSafe["web_fetch"] + return ok +} + +func formatFetchResults(results []string) string { + var output strings.Builder + for i, res := range results { + output.WriteString(fmt.Sprintf("=== Result %d ===\n%s\n\n", i+1, res)) + } + return output.String() +} + +func minParallelLimit(maxParallel, total int) int { + if maxParallel <= 0 { + return 1 + } + if total <= 0 { + return maxParallel + } + if maxParallel > total { + return total + } + return maxParallel } diff --git a/pkg/tools/parallel_test.go b/pkg/tools/parallel_test.go new file mode 100644 index 0000000..15e5291 --- /dev/null +++ b/pkg/tools/parallel_test.go @@ -0,0 +1,278 @@ +package tools + +import ( + "context" + "errors" + "strings" + "sync" + "sync/atomic" + "testing" + "time" +) + +type basicTool struct { + name string + execute func(ctx context.Context, args map[string]interface{}) (string, error) +} + +func (t *basicTool) Name() string { + return t.name +} + +func (t *basicTool) Description() string { + return "test tool" +} + +func (t *basicTool) Parameters() map[string]interface{} { + return map[string]interface{}{"type": "object"} +} + +func (t *basicTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) { + return t.execute(ctx, args) +} + +type safeTool struct { + *basicTool +} + +func (t *safeTool) ParallelSafe() bool { + return true +} + +type concurrencyTool struct { + name string + delay time.Duration + current int32 + max int32 +} + +func (t *concurrencyTool) Name() string { + return t.name +} + +func (t *concurrencyTool) Description() string { + return "concurrency test tool" +} + +func (t *concurrencyTool) Parameters() map[string]interface{} { + return map[string]interface{}{"type": "object"} +} + +func (t *concurrencyTool) ParallelSafe() bool { + return true +} + +func (t *concurrencyTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) { + current := atomic.AddInt32(&t.current, 1) + for { + max := atomic.LoadInt32(&t.max) + if current <= max { + break + } + if atomic.CompareAndSwapInt32(&t.max, max, current) { + break + } + } + time.Sleep(t.delay) + atomic.AddInt32(&t.current, -1) + return "ok", nil +} + +type conflictTool struct { + name string + delay time.Duration + mu sync.Mutex + active map[string]bool + conflicts int32 +} + +func (t *conflictTool) Name() string { + return t.name +} + +func (t *conflictTool) Description() string { + return "resource conflict test tool" +} + +func (t *conflictTool) Parameters() map[string]interface{} { + return map[string]interface{}{"type": "object"} +} + +func (t *conflictTool) ParallelSafe() bool { + return true +} + +func (t *conflictTool) ResourceKeys(args map[string]interface{}) []string { + key, _ := args["key"].(string) + if key == "" { + return nil + } + return []string{key} +} + +func (t *conflictTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) { + key, _ := args["key"].(string) + if key == "" { + return "", errors.New("missing key") + } + defer func() { + t.mu.Lock() + delete(t.active, key) + t.mu.Unlock() + }() + + t.mu.Lock() + if t.active == nil { + t.active = make(map[string]bool) + } + if t.active[key] { + atomic.AddInt32(&t.conflicts, 1) + } + t.active[key] = true + t.mu.Unlock() + + time.Sleep(t.delay) + return "ok", nil +} + +func TestParallelToolStableOrdering(t *testing.T) { + registry := NewToolRegistry() + tool := &safeTool{&basicTool{ + name: "echo", + execute: func(ctx context.Context, args map[string]interface{}) (string, error) { + delay := 0 * time.Millisecond + switch v := args["delay"].(type) { + case int: + delay = time.Duration(v) * time.Millisecond + case float64: + delay = time.Duration(v) * time.Millisecond + } + if delay > 0 { + time.Sleep(delay) + } + value, _ := args["value"].(string) + return value, nil + }, + }} + registry.Register(tool) + + parallel := NewParallelTool(registry, 3, nil) + calls := []interface{}{ + map[string]interface{}{ + "tool": "echo", + "arguments": map[string]interface{}{"value": "first", "delay": 40}, + "id": "first", + }, + map[string]interface{}{ + "tool": "echo", + "arguments": map[string]interface{}{"value": "second", "delay": 10}, + "id": "second", + }, + map[string]interface{}{ + "tool": "echo", + "arguments": map[string]interface{}{"value": "third", "delay": 20}, + "id": "third", + }, + } + + output, err := parallel.Execute(context.Background(), map[string]interface{}{"calls": calls}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + firstIdx := strings.Index(output, "Result for first") + secondIdx := strings.Index(output, "Result for second") + thirdIdx := strings.Index(output, "Result for third") + if firstIdx == -1 || secondIdx == -1 || thirdIdx == -1 { + t.Fatalf("missing result markers in output: %s", output) + } + if !(firstIdx < secondIdx && secondIdx < thirdIdx) { + t.Fatalf("results not in call order: %s", output) + } +} + +func TestParallelToolErrorFormatting(t *testing.T) { + registry := NewToolRegistry() + tool := &safeTool{&basicTool{ + name: "fail", + execute: func(ctx context.Context, args map[string]interface{}) (string, error) { + return "", errors.New("boom") + }, + }} + registry.Register(tool) + + parallel := NewParallelTool(registry, 2, nil) + calls := []interface{}{ + map[string]interface{}{ + "tool": "fail", + "arguments": map[string]interface{}{}, + "id": "err", + }, + } + + output, err := parallel.Execute(context.Background(), map[string]interface{}{"calls": calls}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !strings.Contains(output, "Error: boom") { + t.Fatalf("expected formatted error, got: %s", output) + } +} + +func TestParallelToolConcurrencyLimit(t *testing.T) { + registry := NewToolRegistry() + tool := &concurrencyTool{name: "sleep", delay: 25 * time.Millisecond} + registry.Register(tool) + + parallel := NewParallelTool(registry, 2, nil) + calls := make([]interface{}, 5) + for i := 0; i < len(calls); i++ { + calls[i] = map[string]interface{}{ + "tool": "sleep", + "arguments": map[string]interface{}{}, + "id": "call", + } + } + + _, err := parallel.Execute(context.Background(), map[string]interface{}{"calls": calls}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if max := atomic.LoadInt32(&tool.max); max > 2 { + t.Fatalf("expected max concurrency <= 2, got %d", max) + } +} + +func TestParallelToolResourceBatching(t *testing.T) { + registry := NewToolRegistry() + tool := &conflictTool{name: "resource", delay: 30 * time.Millisecond} + registry.Register(tool) + + parallel := NewParallelTool(registry, 3, nil) + calls := []interface{}{ + map[string]interface{}{ + "tool": "resource", + "arguments": map[string]interface{}{"key": "alpha"}, + "id": "first", + }, + map[string]interface{}{ + "tool": "resource", + "arguments": map[string]interface{}{"key": "beta"}, + "id": "second", + }, + map[string]interface{}{ + "tool": "resource", + "arguments": map[string]interface{}{"key": "alpha"}, + "id": "third", + }, + } + + _, err := parallel.Execute(context.Background(), map[string]interface{}{"calls": calls}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if conflicts := atomic.LoadInt32(&tool.conflicts); conflicts > 0 { + t.Fatalf("expected no resource conflicts, got %d", conflicts) + } +} diff --git a/pkg/tools/web.go b/pkg/tools/web.go index 81a6523..faef9a4 100644 --- a/pkg/tools/web.go +++ b/pkg/tools/web.go @@ -12,7 +12,6 @@ import ( "time" "clawgo/pkg/logger" - ) const ( @@ -144,6 +143,10 @@ type WebFetchTool struct { maxChars int } +func (t *WebFetchTool) ParallelSafe() bool { + return true +} + func NewWebFetchTool(maxChars int) *WebFetchTool { if maxChars <= 0 { maxChars = 50000