mirror of
https://github.com/YspCoder/clawgo.git
synced 2026-04-30 18:37:28 +08:00
tools: align parallel execution with runtime limits
This commit is contained in:
@@ -3,17 +3,23 @@ package tools
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
const maxParallelFetchCalls = 8
|
||||
|
||||
type ParallelFetchTool struct {
|
||||
fetcher *WebFetchTool
|
||||
fetcher *WebFetchTool
|
||||
maxParallelCalls int
|
||||
parallelSafe map[string]struct{}
|
||||
}
|
||||
|
||||
func NewParallelFetchTool(fetcher *WebFetchTool) *ParallelFetchTool {
|
||||
return &ParallelFetchTool{fetcher: fetcher}
|
||||
func NewParallelFetchTool(fetcher *WebFetchTool, maxParallelCalls int, parallelSafe map[string]struct{}) *ParallelFetchTool {
|
||||
limit := normalizeParallelLimit(maxParallelCalls)
|
||||
return &ParallelFetchTool{
|
||||
fetcher: fetcher,
|
||||
maxParallelCalls: limit,
|
||||
parallelSafe: normalizeSafeToolNames(parallelSafe),
|
||||
}
|
||||
}
|
||||
|
||||
func (t *ParallelFetchTool) Name() string {
|
||||
@@ -46,20 +52,30 @@ func (t *ParallelFetchTool) Execute(ctx context.Context, args map[string]interfa
|
||||
return "", fmt.Errorf("urls must be an array")
|
||||
}
|
||||
|
||||
maxParallel := t.maxParallelCalls
|
||||
if maxParallel <= 1 {
|
||||
return t.executeSerial(ctx, urlsRaw), nil
|
||||
}
|
||||
|
||||
if !t.isParallelSafe() {
|
||||
return t.executeSerial(ctx, urlsRaw), nil
|
||||
}
|
||||
|
||||
results := make([]string, len(urlsRaw))
|
||||
var wg sync.WaitGroup
|
||||
sem := make(chan struct{}, maxParallelFetchCalls)
|
||||
sem := make(chan struct{}, minParallelLimit(maxParallel, len(urlsRaw)))
|
||||
|
||||
for i, u := range urlsRaw {
|
||||
urlStr, ok := u.(string)
|
||||
if !ok {
|
||||
results[i] = "Error: invalid url"
|
||||
continue
|
||||
}
|
||||
|
||||
wg.Add(1)
|
||||
sem <- struct{}{}
|
||||
go func(index int, url string) {
|
||||
defer wg.Done()
|
||||
sem <- struct{}{}
|
||||
defer func() { <-sem }()
|
||||
|
||||
res, err := t.fetcher.Execute(ctx, map[string]interface{}{"url": url})
|
||||
@@ -73,10 +89,55 @@ func (t *ParallelFetchTool) Execute(ctx context.Context, args map[string]interfa
|
||||
|
||||
wg.Wait()
|
||||
|
||||
var output string
|
||||
for i, res := range results {
|
||||
output += fmt.Sprintf("=== Result %d ===\n%s\n\n", i+1, res)
|
||||
}
|
||||
|
||||
return output, nil
|
||||
return formatFetchResults(results), nil
|
||||
}
|
||||
|
||||
func (t *ParallelFetchTool) executeSerial(ctx context.Context, urlsRaw []interface{}) string {
|
||||
results := make([]string, len(urlsRaw))
|
||||
for i, u := range urlsRaw {
|
||||
urlStr, ok := u.(string)
|
||||
if !ok {
|
||||
results[i] = "Error: invalid url"
|
||||
continue
|
||||
}
|
||||
res, err := t.fetcher.Execute(ctx, map[string]interface{}{"url": urlStr})
|
||||
if err != nil {
|
||||
results[i] = fmt.Sprintf("Error fetching %s: %v", urlStr, err)
|
||||
} else {
|
||||
results[i] = res
|
||||
}
|
||||
}
|
||||
return formatFetchResults(results)
|
||||
}
|
||||
|
||||
func (t *ParallelFetchTool) isParallelSafe() bool {
|
||||
if t.parallelSafe == nil {
|
||||
return false
|
||||
}
|
||||
if tool, ok := any(t.fetcher).(ParallelSafeTool); ok {
|
||||
return tool.ParallelSafe()
|
||||
}
|
||||
_, ok := t.parallelSafe["web_fetch"]
|
||||
return ok
|
||||
}
|
||||
|
||||
func formatFetchResults(results []string) string {
|
||||
var output strings.Builder
|
||||
for i, res := range results {
|
||||
output.WriteString(fmt.Sprintf("=== Result %d ===\n%s\n\n", i+1, res))
|
||||
}
|
||||
return output.String()
|
||||
}
|
||||
|
||||
func minParallelLimit(maxParallel, total int) int {
|
||||
if maxParallel <= 0 {
|
||||
return 1
|
||||
}
|
||||
if total <= 0 {
|
||||
return maxParallel
|
||||
}
|
||||
if maxParallel > total {
|
||||
return total
|
||||
}
|
||||
return maxParallel
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user