mirror of
https://github.com/YspCoder/clawgo.git
synced 2026-04-12 19:57:29 +08:00
tools: align parallel execution with runtime limits
This commit is contained in:
14
Makefile
14
Makefile
@@ -1,4 +1,4 @@
|
||||
.PHONY: all build build-all install uninstall clean help test install-bootstrap-docs sync-embed-workspace cleanup-embed-workspace
|
||||
.PHONY: all build build-all install uninstall clean help test install-bootstrap-docs sync-embed-workspace cleanup-embed-workspace test-only clean-test-artifacts
|
||||
|
||||
# Build variables
|
||||
BINARY_NAME=clawgo
|
||||
@@ -142,9 +142,15 @@ clean:
|
||||
@rm -rf $(BUILD_DIR)
|
||||
@echo "Clean complete"
|
||||
|
||||
## fmt: Format Go code
|
||||
fmt:
|
||||
@$(GO) fmt ./...
|
||||
## test-only: Run tests without leaving build artifacts (cleans embed workspace and test cache)
|
||||
test-only: sync-embed-workspace
|
||||
@echo "Running tests..."
|
||||
@set -e; trap '$(MAKE) cleanup-embed-workspace clean-test-artifacts' EXIT; \
|
||||
$(GO) test ./...
|
||||
|
||||
## clean-test-artifacts: Remove test caches/artifacts generated by go test
|
||||
clean-test-artifacts:
|
||||
@$(GO) clean -testcache
|
||||
|
||||
## deps: Update dependencies
|
||||
deps:
|
||||
|
||||
@@ -479,7 +479,6 @@ func NewAgentLoop(cfg *config.Config, msgBus *bus.MessageBus, provider providers
|
||||
toolsRegistry.Register(tools.NewWebSearchTool(braveAPIKey, cfg.Tools.Web.Search.MaxResults))
|
||||
webFetchTool := tools.NewWebFetchTool(50000)
|
||||
toolsRegistry.Register(webFetchTool)
|
||||
toolsRegistry.Register(tools.NewParallelFetchTool(webFetchTool))
|
||||
|
||||
// Register message tool
|
||||
messageTool := tools.NewMessageTool()
|
||||
@@ -514,9 +513,6 @@ func NewAgentLoop(cfg *config.Config, msgBus *bus.MessageBus, provider providers
|
||||
toolsRegistry.Register(tools.NewRepoMapTool(workspace))
|
||||
toolsRegistry.Register(tools.NewSkillExecTool(workspace))
|
||||
|
||||
// Register parallel execution tool (leveraging Go's concurrency)
|
||||
toolsRegistry.Register(tools.NewParallelTool(toolsRegistry))
|
||||
|
||||
// Register browser tool (integrated Chromium support)
|
||||
toolsRegistry.Register(tools.NewBrowserTool())
|
||||
|
||||
@@ -555,6 +551,8 @@ func NewAgentLoop(cfg *config.Config, msgBus *bus.MessageBus, provider providers
|
||||
policy := loadControlPolicyFromConfig(defaultControlPolicy(), cfg.Agents.Defaults.RuntimeControl)
|
||||
policy = applyLegacyControlPolicyEnvOverrides(policy)
|
||||
parallelSafeTools, maxParallelCalls := loadToolParallelPolicyFromConfig(cfg.Agents.Defaults.RuntimeControl)
|
||||
toolsRegistry.Register(tools.NewParallelTool(toolsRegistry, maxParallelCalls, parallelSafeTools))
|
||||
toolsRegistry.Register(tools.NewParallelFetchTool(webFetchTool, maxParallelCalls, parallelSafeTools))
|
||||
runStateTTL, runStateMax := loadRunStatePolicyFromConfig(cfg.Agents.Defaults.RuntimeControl)
|
||||
// Keep compatibility with older env names.
|
||||
runStateTTL = envDuration("CLAWGO_RUN_STATE_TTL", runStateTTL)
|
||||
|
||||
@@ -3,17 +3,36 @@ package tools
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
const maxParallelToolCalls = 8
|
||||
const (
|
||||
maxParallelToolCallsLimit = 8
|
||||
defaultParallelToolCalls = 2
|
||||
)
|
||||
|
||||
type ParallelTool struct {
|
||||
registry *ToolRegistry
|
||||
type parallelCall struct {
|
||||
Index int
|
||||
Tool string
|
||||
Args map[string]interface{}
|
||||
ResultID string
|
||||
}
|
||||
|
||||
func NewParallelTool(registry *ToolRegistry) *ParallelTool {
|
||||
return &ParallelTool{registry: registry}
|
||||
type ParallelTool struct {
|
||||
registry *ToolRegistry
|
||||
maxParallelCalls int
|
||||
parallelSafe map[string]struct{}
|
||||
}
|
||||
|
||||
func NewParallelTool(registry *ToolRegistry, maxParallelCalls int, parallelSafe map[string]struct{}) *ParallelTool {
|
||||
limit := normalizeParallelLimit(maxParallelCalls)
|
||||
return &ParallelTool{
|
||||
registry: registry,
|
||||
maxParallelCalls: limit,
|
||||
parallelSafe: normalizeSafeToolNames(parallelSafe),
|
||||
}
|
||||
}
|
||||
|
||||
func (t *ParallelTool) Name() string {
|
||||
@@ -60,11 +79,7 @@ func (t *ParallelTool) Execute(ctx context.Context, args map[string]interface{})
|
||||
return "", fmt.Errorf("calls must be an array")
|
||||
}
|
||||
|
||||
results := make(map[string]string)
|
||||
var mu sync.Mutex
|
||||
var wg sync.WaitGroup
|
||||
sem := make(chan struct{}, maxParallelToolCalls)
|
||||
|
||||
calls := make([]parallelCall, 0, len(callsRaw))
|
||||
for i, c := range callsRaw {
|
||||
call, ok := c.(map[string]interface{})
|
||||
if !ok {
|
||||
@@ -78,31 +93,236 @@ func (t *ParallelTool) Execute(ctx context.Context, args map[string]interface{})
|
||||
id = fmt.Sprintf("call_%d_%s", i, toolName)
|
||||
}
|
||||
|
||||
wg.Add(1)
|
||||
go func(id, name string, args map[string]interface{}) {
|
||||
defer wg.Done()
|
||||
sem <- struct{}{}
|
||||
defer func() { <-sem }()
|
||||
|
||||
res, err := t.registry.Execute(ctx, name, args)
|
||||
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
if err != nil {
|
||||
results[id] = fmt.Sprintf("Error: %v", err)
|
||||
} else {
|
||||
results[id] = res
|
||||
}
|
||||
}(id, toolName, toolArgs)
|
||||
calls = append(calls, parallelCall{
|
||||
Index: i,
|
||||
Tool: toolName,
|
||||
Args: toolArgs,
|
||||
ResultID: id,
|
||||
})
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// Simple string representation for the agent
|
||||
var output string
|
||||
for id, res := range results {
|
||||
output += fmt.Sprintf("--- Result for %s ---\n%s\n", id, res)
|
||||
if len(calls) == 0 {
|
||||
return "", fmt.Errorf("no valid calls provided")
|
||||
}
|
||||
|
||||
return output, nil
|
||||
if !t.callsParallelSafe(calls) {
|
||||
return t.executeSerial(ctx, calls), nil
|
||||
}
|
||||
|
||||
batches := buildParallelBatches(t.registry, calls)
|
||||
var output strings.Builder
|
||||
for _, batch := range batches {
|
||||
if len(batch) == 0 {
|
||||
continue
|
||||
}
|
||||
if len(batch) == 1 || t.maxParallelCalls <= 1 {
|
||||
output.WriteString(t.executeSerial(ctx, batch))
|
||||
continue
|
||||
}
|
||||
output.WriteString(t.executeParallel(ctx, batch))
|
||||
}
|
||||
return output.String(), nil
|
||||
}
|
||||
|
||||
func (t *ParallelTool) executeSerial(ctx context.Context, calls []parallelCall) string {
|
||||
results := make([]parallelResult, 0, len(calls))
|
||||
for _, call := range calls {
|
||||
res, err := t.registry.Execute(ctx, call.Tool, call.Args)
|
||||
results = append(results, parallelResult{
|
||||
Index: call.Index,
|
||||
ResultID: call.ResultID,
|
||||
Output: formatToolResult(res, err),
|
||||
})
|
||||
}
|
||||
return formatParallelResults(results)
|
||||
}
|
||||
|
||||
func (t *ParallelTool) executeParallel(ctx context.Context, calls []parallelCall) string {
|
||||
limit := t.maxParallelCalls
|
||||
if limit <= 0 {
|
||||
limit = defaultParallelToolCalls
|
||||
}
|
||||
if limit > len(calls) {
|
||||
limit = len(calls)
|
||||
}
|
||||
|
||||
results := make([]parallelResult, len(calls))
|
||||
var wg sync.WaitGroup
|
||||
sem := make(chan struct{}, limit)
|
||||
for i, call := range calls {
|
||||
wg.Add(1)
|
||||
sem <- struct{}{}
|
||||
go func(index int, call parallelCall) {
|
||||
defer wg.Done()
|
||||
defer func() { <-sem }()
|
||||
res, err := t.registry.Execute(ctx, call.Tool, call.Args)
|
||||
results[index] = parallelResult{
|
||||
Index: call.Index,
|
||||
ResultID: call.ResultID,
|
||||
Output: formatToolResult(res, err),
|
||||
}
|
||||
}(i, call)
|
||||
}
|
||||
wg.Wait()
|
||||
return formatParallelResults(results)
|
||||
}
|
||||
|
||||
func (t *ParallelTool) callsParallelSafe(calls []parallelCall) bool {
|
||||
for _, call := range calls {
|
||||
name := strings.ToLower(strings.TrimSpace(call.Tool))
|
||||
if name == "" {
|
||||
return false
|
||||
}
|
||||
if tool, ok := t.registry.Get(call.Tool); ok {
|
||||
if ps, ok := tool.(ParallelSafeTool); ok {
|
||||
if !ps.ParallelSafe() {
|
||||
return false
|
||||
}
|
||||
continue
|
||||
}
|
||||
}
|
||||
if _, ok := t.parallelSafe[name]; !ok {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func buildParallelBatches(registry *ToolRegistry, calls []parallelCall) [][]parallelCall {
|
||||
if len(calls) == 0 {
|
||||
return nil
|
||||
}
|
||||
batches := make([][]parallelCall, 0, len(calls))
|
||||
current := make([]parallelCall, 0, len(calls))
|
||||
used := map[string]struct{}{}
|
||||
|
||||
flush := func() {
|
||||
if len(current) == 0 {
|
||||
return
|
||||
}
|
||||
batch := append([]parallelCall(nil), current...)
|
||||
batches = append(batches, batch)
|
||||
current = current[:0]
|
||||
used = map[string]struct{}{}
|
||||
}
|
||||
|
||||
for _, call := range calls {
|
||||
keys := toolResourceKeys(registry, call.Tool, call.Args)
|
||||
if len(current) > 0 && hasResourceKeyConflict(used, keys) {
|
||||
flush()
|
||||
}
|
||||
current = append(current, call)
|
||||
for _, k := range keys {
|
||||
used[k] = struct{}{}
|
||||
}
|
||||
}
|
||||
flush()
|
||||
return batches
|
||||
}
|
||||
|
||||
func toolResourceKeys(registry *ToolRegistry, name string, args map[string]interface{}) []string {
|
||||
raw := strings.TrimSpace(name)
|
||||
lower := strings.ToLower(raw)
|
||||
if raw == "" || registry == nil {
|
||||
return nil
|
||||
}
|
||||
tool, ok := registry.Get(raw)
|
||||
if !ok && lower != raw {
|
||||
tool, ok = registry.Get(lower)
|
||||
}
|
||||
if !ok || tool == nil {
|
||||
return nil
|
||||
}
|
||||
rs, ok := tool.(ResourceScopedTool)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
return normalizeResourceKeys(rs.ResourceKeys(args))
|
||||
}
|
||||
|
||||
func normalizeResourceKeys(keys []string) []string {
|
||||
if len(keys) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]string, 0, len(keys))
|
||||
seen := make(map[string]struct{}, len(keys))
|
||||
for _, k := range keys {
|
||||
n := strings.ToLower(strings.TrimSpace(k))
|
||||
if n == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[n]; ok {
|
||||
continue
|
||||
}
|
||||
seen[n] = struct{}{}
|
||||
out = append(out, n)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func hasResourceKeyConflict(used map[string]struct{}, keys []string) bool {
|
||||
if len(keys) == 0 || len(used) == 0 {
|
||||
return false
|
||||
}
|
||||
for _, k := range keys {
|
||||
if _, ok := used[k]; ok {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
type parallelResult struct {
|
||||
Index int
|
||||
ResultID string
|
||||
Output string
|
||||
}
|
||||
|
||||
func formatParallelResults(results []parallelResult) string {
|
||||
sort.SliceStable(results, func(i, j int) bool {
|
||||
return results[i].Index < results[j].Index
|
||||
})
|
||||
var output strings.Builder
|
||||
for _, res := range results {
|
||||
if res.ResultID == "" {
|
||||
continue
|
||||
}
|
||||
output.WriteString(fmt.Sprintf("--- Result for %s ---\n%s\n", res.ResultID, res.Output))
|
||||
}
|
||||
return output.String()
|
||||
}
|
||||
|
||||
func formatToolResult(result string, err error) string {
|
||||
if err != nil {
|
||||
return fmt.Sprintf("Error: %v", err)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func normalizeParallelLimit(limit int) int {
|
||||
if limit <= 0 {
|
||||
limit = defaultParallelToolCalls
|
||||
}
|
||||
if limit < 1 {
|
||||
limit = 1
|
||||
}
|
||||
if limit > maxParallelToolCallsLimit {
|
||||
limit = maxParallelToolCallsLimit
|
||||
}
|
||||
return limit
|
||||
}
|
||||
|
||||
func normalizeSafeToolNames(names map[string]struct{}) map[string]struct{} {
|
||||
if len(names) == 0 {
|
||||
return map[string]struct{}{}
|
||||
}
|
||||
out := make(map[string]struct{}, len(names))
|
||||
for name := range names {
|
||||
n := strings.ToLower(strings.TrimSpace(name))
|
||||
if n == "" {
|
||||
continue
|
||||
}
|
||||
out[n] = struct{}{}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
@@ -3,17 +3,23 @@ package tools
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
const maxParallelFetchCalls = 8
|
||||
|
||||
type ParallelFetchTool struct {
|
||||
fetcher *WebFetchTool
|
||||
fetcher *WebFetchTool
|
||||
maxParallelCalls int
|
||||
parallelSafe map[string]struct{}
|
||||
}
|
||||
|
||||
func NewParallelFetchTool(fetcher *WebFetchTool) *ParallelFetchTool {
|
||||
return &ParallelFetchTool{fetcher: fetcher}
|
||||
func NewParallelFetchTool(fetcher *WebFetchTool, maxParallelCalls int, parallelSafe map[string]struct{}) *ParallelFetchTool {
|
||||
limit := normalizeParallelLimit(maxParallelCalls)
|
||||
return &ParallelFetchTool{
|
||||
fetcher: fetcher,
|
||||
maxParallelCalls: limit,
|
||||
parallelSafe: normalizeSafeToolNames(parallelSafe),
|
||||
}
|
||||
}
|
||||
|
||||
func (t *ParallelFetchTool) Name() string {
|
||||
@@ -46,20 +52,30 @@ func (t *ParallelFetchTool) Execute(ctx context.Context, args map[string]interfa
|
||||
return "", fmt.Errorf("urls must be an array")
|
||||
}
|
||||
|
||||
maxParallel := t.maxParallelCalls
|
||||
if maxParallel <= 1 {
|
||||
return t.executeSerial(ctx, urlsRaw), nil
|
||||
}
|
||||
|
||||
if !t.isParallelSafe() {
|
||||
return t.executeSerial(ctx, urlsRaw), nil
|
||||
}
|
||||
|
||||
results := make([]string, len(urlsRaw))
|
||||
var wg sync.WaitGroup
|
||||
sem := make(chan struct{}, maxParallelFetchCalls)
|
||||
sem := make(chan struct{}, minParallelLimit(maxParallel, len(urlsRaw)))
|
||||
|
||||
for i, u := range urlsRaw {
|
||||
urlStr, ok := u.(string)
|
||||
if !ok {
|
||||
results[i] = "Error: invalid url"
|
||||
continue
|
||||
}
|
||||
|
||||
wg.Add(1)
|
||||
sem <- struct{}{}
|
||||
go func(index int, url string) {
|
||||
defer wg.Done()
|
||||
sem <- struct{}{}
|
||||
defer func() { <-sem }()
|
||||
|
||||
res, err := t.fetcher.Execute(ctx, map[string]interface{}{"url": url})
|
||||
@@ -73,10 +89,55 @@ func (t *ParallelFetchTool) Execute(ctx context.Context, args map[string]interfa
|
||||
|
||||
wg.Wait()
|
||||
|
||||
var output string
|
||||
for i, res := range results {
|
||||
output += fmt.Sprintf("=== Result %d ===\n%s\n\n", i+1, res)
|
||||
}
|
||||
|
||||
return output, nil
|
||||
return formatFetchResults(results), nil
|
||||
}
|
||||
|
||||
func (t *ParallelFetchTool) executeSerial(ctx context.Context, urlsRaw []interface{}) string {
|
||||
results := make([]string, len(urlsRaw))
|
||||
for i, u := range urlsRaw {
|
||||
urlStr, ok := u.(string)
|
||||
if !ok {
|
||||
results[i] = "Error: invalid url"
|
||||
continue
|
||||
}
|
||||
res, err := t.fetcher.Execute(ctx, map[string]interface{}{"url": urlStr})
|
||||
if err != nil {
|
||||
results[i] = fmt.Sprintf("Error fetching %s: %v", urlStr, err)
|
||||
} else {
|
||||
results[i] = res
|
||||
}
|
||||
}
|
||||
return formatFetchResults(results)
|
||||
}
|
||||
|
||||
func (t *ParallelFetchTool) isParallelSafe() bool {
|
||||
if t.parallelSafe == nil {
|
||||
return false
|
||||
}
|
||||
if tool, ok := any(t.fetcher).(ParallelSafeTool); ok {
|
||||
return tool.ParallelSafe()
|
||||
}
|
||||
_, ok := t.parallelSafe["web_fetch"]
|
||||
return ok
|
||||
}
|
||||
|
||||
func formatFetchResults(results []string) string {
|
||||
var output strings.Builder
|
||||
for i, res := range results {
|
||||
output.WriteString(fmt.Sprintf("=== Result %d ===\n%s\n\n", i+1, res))
|
||||
}
|
||||
return output.String()
|
||||
}
|
||||
|
||||
func minParallelLimit(maxParallel, total int) int {
|
||||
if maxParallel <= 0 {
|
||||
return 1
|
||||
}
|
||||
if total <= 0 {
|
||||
return maxParallel
|
||||
}
|
||||
if maxParallel > total {
|
||||
return total
|
||||
}
|
||||
return maxParallel
|
||||
}
|
||||
|
||||
278
pkg/tools/parallel_test.go
Normal file
278
pkg/tools/parallel_test.go
Normal file
@@ -0,0 +1,278 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
type basicTool struct {
|
||||
name string
|
||||
execute func(ctx context.Context, args map[string]interface{}) (string, error)
|
||||
}
|
||||
|
||||
func (t *basicTool) Name() string {
|
||||
return t.name
|
||||
}
|
||||
|
||||
func (t *basicTool) Description() string {
|
||||
return "test tool"
|
||||
}
|
||||
|
||||
func (t *basicTool) Parameters() map[string]interface{} {
|
||||
return map[string]interface{}{"type": "object"}
|
||||
}
|
||||
|
||||
func (t *basicTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) {
|
||||
return t.execute(ctx, args)
|
||||
}
|
||||
|
||||
type safeTool struct {
|
||||
*basicTool
|
||||
}
|
||||
|
||||
func (t *safeTool) ParallelSafe() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
type concurrencyTool struct {
|
||||
name string
|
||||
delay time.Duration
|
||||
current int32
|
||||
max int32
|
||||
}
|
||||
|
||||
func (t *concurrencyTool) Name() string {
|
||||
return t.name
|
||||
}
|
||||
|
||||
func (t *concurrencyTool) Description() string {
|
||||
return "concurrency test tool"
|
||||
}
|
||||
|
||||
func (t *concurrencyTool) Parameters() map[string]interface{} {
|
||||
return map[string]interface{}{"type": "object"}
|
||||
}
|
||||
|
||||
func (t *concurrencyTool) ParallelSafe() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (t *concurrencyTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) {
|
||||
current := atomic.AddInt32(&t.current, 1)
|
||||
for {
|
||||
max := atomic.LoadInt32(&t.max)
|
||||
if current <= max {
|
||||
break
|
||||
}
|
||||
if atomic.CompareAndSwapInt32(&t.max, max, current) {
|
||||
break
|
||||
}
|
||||
}
|
||||
time.Sleep(t.delay)
|
||||
atomic.AddInt32(&t.current, -1)
|
||||
return "ok", nil
|
||||
}
|
||||
|
||||
type conflictTool struct {
|
||||
name string
|
||||
delay time.Duration
|
||||
mu sync.Mutex
|
||||
active map[string]bool
|
||||
conflicts int32
|
||||
}
|
||||
|
||||
func (t *conflictTool) Name() string {
|
||||
return t.name
|
||||
}
|
||||
|
||||
func (t *conflictTool) Description() string {
|
||||
return "resource conflict test tool"
|
||||
}
|
||||
|
||||
func (t *conflictTool) Parameters() map[string]interface{} {
|
||||
return map[string]interface{}{"type": "object"}
|
||||
}
|
||||
|
||||
func (t *conflictTool) ParallelSafe() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (t *conflictTool) ResourceKeys(args map[string]interface{}) []string {
|
||||
key, _ := args["key"].(string)
|
||||
if key == "" {
|
||||
return nil
|
||||
}
|
||||
return []string{key}
|
||||
}
|
||||
|
||||
func (t *conflictTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) {
|
||||
key, _ := args["key"].(string)
|
||||
if key == "" {
|
||||
return "", errors.New("missing key")
|
||||
}
|
||||
defer func() {
|
||||
t.mu.Lock()
|
||||
delete(t.active, key)
|
||||
t.mu.Unlock()
|
||||
}()
|
||||
|
||||
t.mu.Lock()
|
||||
if t.active == nil {
|
||||
t.active = make(map[string]bool)
|
||||
}
|
||||
if t.active[key] {
|
||||
atomic.AddInt32(&t.conflicts, 1)
|
||||
}
|
||||
t.active[key] = true
|
||||
t.mu.Unlock()
|
||||
|
||||
time.Sleep(t.delay)
|
||||
return "ok", nil
|
||||
}
|
||||
|
||||
func TestParallelToolStableOrdering(t *testing.T) {
|
||||
registry := NewToolRegistry()
|
||||
tool := &safeTool{&basicTool{
|
||||
name: "echo",
|
||||
execute: func(ctx context.Context, args map[string]interface{}) (string, error) {
|
||||
delay := 0 * time.Millisecond
|
||||
switch v := args["delay"].(type) {
|
||||
case int:
|
||||
delay = time.Duration(v) * time.Millisecond
|
||||
case float64:
|
||||
delay = time.Duration(v) * time.Millisecond
|
||||
}
|
||||
if delay > 0 {
|
||||
time.Sleep(delay)
|
||||
}
|
||||
value, _ := args["value"].(string)
|
||||
return value, nil
|
||||
},
|
||||
}}
|
||||
registry.Register(tool)
|
||||
|
||||
parallel := NewParallelTool(registry, 3, nil)
|
||||
calls := []interface{}{
|
||||
map[string]interface{}{
|
||||
"tool": "echo",
|
||||
"arguments": map[string]interface{}{"value": "first", "delay": 40},
|
||||
"id": "first",
|
||||
},
|
||||
map[string]interface{}{
|
||||
"tool": "echo",
|
||||
"arguments": map[string]interface{}{"value": "second", "delay": 10},
|
||||
"id": "second",
|
||||
},
|
||||
map[string]interface{}{
|
||||
"tool": "echo",
|
||||
"arguments": map[string]interface{}{"value": "third", "delay": 20},
|
||||
"id": "third",
|
||||
},
|
||||
}
|
||||
|
||||
output, err := parallel.Execute(context.Background(), map[string]interface{}{"calls": calls})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
firstIdx := strings.Index(output, "Result for first")
|
||||
secondIdx := strings.Index(output, "Result for second")
|
||||
thirdIdx := strings.Index(output, "Result for third")
|
||||
if firstIdx == -1 || secondIdx == -1 || thirdIdx == -1 {
|
||||
t.Fatalf("missing result markers in output: %s", output)
|
||||
}
|
||||
if !(firstIdx < secondIdx && secondIdx < thirdIdx) {
|
||||
t.Fatalf("results not in call order: %s", output)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParallelToolErrorFormatting(t *testing.T) {
|
||||
registry := NewToolRegistry()
|
||||
tool := &safeTool{&basicTool{
|
||||
name: "fail",
|
||||
execute: func(ctx context.Context, args map[string]interface{}) (string, error) {
|
||||
return "", errors.New("boom")
|
||||
},
|
||||
}}
|
||||
registry.Register(tool)
|
||||
|
||||
parallel := NewParallelTool(registry, 2, nil)
|
||||
calls := []interface{}{
|
||||
map[string]interface{}{
|
||||
"tool": "fail",
|
||||
"arguments": map[string]interface{}{},
|
||||
"id": "err",
|
||||
},
|
||||
}
|
||||
|
||||
output, err := parallel.Execute(context.Background(), map[string]interface{}{"calls": calls})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if !strings.Contains(output, "Error: boom") {
|
||||
t.Fatalf("expected formatted error, got: %s", output)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParallelToolConcurrencyLimit(t *testing.T) {
|
||||
registry := NewToolRegistry()
|
||||
tool := &concurrencyTool{name: "sleep", delay: 25 * time.Millisecond}
|
||||
registry.Register(tool)
|
||||
|
||||
parallel := NewParallelTool(registry, 2, nil)
|
||||
calls := make([]interface{}, 5)
|
||||
for i := 0; i < len(calls); i++ {
|
||||
calls[i] = map[string]interface{}{
|
||||
"tool": "sleep",
|
||||
"arguments": map[string]interface{}{},
|
||||
"id": "call",
|
||||
}
|
||||
}
|
||||
|
||||
_, err := parallel.Execute(context.Background(), map[string]interface{}{"calls": calls})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if max := atomic.LoadInt32(&tool.max); max > 2 {
|
||||
t.Fatalf("expected max concurrency <= 2, got %d", max)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParallelToolResourceBatching(t *testing.T) {
|
||||
registry := NewToolRegistry()
|
||||
tool := &conflictTool{name: "resource", delay: 30 * time.Millisecond}
|
||||
registry.Register(tool)
|
||||
|
||||
parallel := NewParallelTool(registry, 3, nil)
|
||||
calls := []interface{}{
|
||||
map[string]interface{}{
|
||||
"tool": "resource",
|
||||
"arguments": map[string]interface{}{"key": "alpha"},
|
||||
"id": "first",
|
||||
},
|
||||
map[string]interface{}{
|
||||
"tool": "resource",
|
||||
"arguments": map[string]interface{}{"key": "beta"},
|
||||
"id": "second",
|
||||
},
|
||||
map[string]interface{}{
|
||||
"tool": "resource",
|
||||
"arguments": map[string]interface{}{"key": "alpha"},
|
||||
"id": "third",
|
||||
},
|
||||
}
|
||||
|
||||
_, err := parallel.Execute(context.Background(), map[string]interface{}{"calls": calls})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if conflicts := atomic.LoadInt32(&tool.conflicts); conflicts > 0 {
|
||||
t.Fatalf("expected no resource conflicts, got %d", conflicts)
|
||||
}
|
||||
}
|
||||
@@ -12,7 +12,6 @@ import (
|
||||
"time"
|
||||
|
||||
"clawgo/pkg/logger"
|
||||
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -144,6 +143,10 @@ type WebFetchTool struct {
|
||||
maxChars int
|
||||
}
|
||||
|
||||
func (t *WebFetchTool) ParallelSafe() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func NewWebFetchTool(maxChars int) *WebFetchTool {
|
||||
if maxChars <= 0 {
|
||||
maxChars = 50000
|
||||
|
||||
Reference in New Issue
Block a user