mirror of
https://github.com/YspCoder/clawgo.git
synced 2026-04-13 20:47:49 +08:00
243 lines
7.0 KiB
Go
243 lines
7.0 KiB
Go
package agent
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"strings"
|
|
"sync"
|
|
"testing"
|
|
"time"
|
|
|
|
"clawgo/pkg/providers"
|
|
"clawgo/pkg/session"
|
|
"clawgo/pkg/tools"
|
|
)
|
|
|
|
type replayScenario string
|
|
|
|
const (
|
|
replayDirectSuccess replayScenario = "direct_success"
|
|
replayOneToolSuccess replayScenario = "one_tool_success"
|
|
replayRepeatedToolCall replayScenario = "repeated_tool_call"
|
|
replayTransientFailure replayScenario = "transient_failure"
|
|
replayPermissionBlock replayScenario = "permission_block"
|
|
)
|
|
|
|
type replayProvider struct {
|
|
mu sync.Mutex
|
|
scenario replayScenario
|
|
planCalls int
|
|
reflectCalls int
|
|
finalizeCalls int
|
|
polishCalls int
|
|
totalCalls int
|
|
}
|
|
|
|
func (p *replayProvider) Chat(ctx context.Context, messages []providers.Message, defs []providers.ToolDefinition, model string, options map[string]interface{}) (*providers.LLMResponse, error) {
|
|
p.mu.Lock()
|
|
defer p.mu.Unlock()
|
|
p.totalCalls++
|
|
|
|
last := ""
|
|
if len(messages) > 0 {
|
|
last = strings.TrimSpace(messages[len(messages)-1].Content)
|
|
}
|
|
|
|
// Phase-2 polish call.
|
|
if len(defs) == 0 && len(messages) > 0 && strings.Contains(strings.ToLower(strings.TrimSpace(messages[0].Content)), "rewrite the draft answer for end users") {
|
|
p.polishCalls++
|
|
return &providers.LLMResponse{Content: "polished final response"}, nil
|
|
}
|
|
|
|
// Reflection call.
|
|
if len(defs) == 0 && strings.Contains(last, "Classify current execution progress using JSON only.") {
|
|
p.reflectCalls++
|
|
switch p.scenario {
|
|
case replayTransientFailure:
|
|
return &providers.LLMResponse{Content: `{"decision":"continue","reason":"transient failures may recover","confidence":0.74}`}, nil
|
|
case replayRepeatedToolCall:
|
|
return &providers.LLMResponse{Content: `{"decision":"continue","reason":"need another attempt","confidence":0.72}`}, nil
|
|
default:
|
|
return &providers.LLMResponse{Content: `{"decision":"continue","reason":"default continue","confidence":0.60}`}, nil
|
|
}
|
|
}
|
|
|
|
// Finalization draft call.
|
|
if len(defs) == 0 {
|
|
p.finalizeCalls++
|
|
return &providers.LLMResponse{Content: "draft final response"}, nil
|
|
}
|
|
|
|
// Planning/tool-loop calls.
|
|
p.planCalls++
|
|
switch p.scenario {
|
|
case replayDirectSuccess:
|
|
return &providers.LLMResponse{Content: "direct completed"}, nil
|
|
case replayOneToolSuccess:
|
|
if p.planCalls == 1 {
|
|
return &providers.LLMResponse{
|
|
Content: "call one tool",
|
|
ToolCalls: []providers.ToolCall{
|
|
{ID: "tc-1", Name: "ok_tool", Arguments: map[string]interface{}{"x": 1}},
|
|
},
|
|
}, nil
|
|
}
|
|
return &providers.LLMResponse{Content: "task completed after one tool"}, nil
|
|
case replayRepeatedToolCall:
|
|
return &providers.LLMResponse{
|
|
Content: "repeating tool",
|
|
ToolCalls: []providers.ToolCall{
|
|
{ID: fmt.Sprintf("tc-r-%d", p.planCalls), Name: "ok_tool", Arguments: map[string]interface{}{"same": true}},
|
|
},
|
|
}, nil
|
|
case replayTransientFailure:
|
|
return &providers.LLMResponse{
|
|
Content: "transient fail tool",
|
|
ToolCalls: []providers.ToolCall{
|
|
{ID: fmt.Sprintf("tc-t-%d", p.planCalls), Name: "fail_tool_transient", Arguments: map[string]interface{}{}},
|
|
},
|
|
}, nil
|
|
case replayPermissionBlock:
|
|
return &providers.LLMResponse{
|
|
Content: "permission fail tool",
|
|
ToolCalls: []providers.ToolCall{
|
|
{ID: "tc-p-1", Name: "fail_tool_permission", Arguments: map[string]interface{}{}},
|
|
},
|
|
}, nil
|
|
default:
|
|
return &providers.LLMResponse{Content: "unexpected scenario"}, nil
|
|
}
|
|
}
|
|
|
|
func (p *replayProvider) GetDefaultModel() string { return "test-model" }
|
|
|
|
type replayToolImpl struct {
|
|
name string
|
|
run func(context.Context, map[string]interface{}) (string, error)
|
|
}
|
|
|
|
func (t replayToolImpl) Name() string { return t.name }
|
|
func (t replayToolImpl) Description() string { return "replay tool" }
|
|
func (t replayToolImpl) Parameters() map[string]interface{} {
|
|
return map[string]interface{}{
|
|
"type": "object",
|
|
"properties": map[string]interface{}{},
|
|
}
|
|
}
|
|
func (t replayToolImpl) Execute(ctx context.Context, args map[string]interface{}) (string, error) {
|
|
return t.run(ctx, args)
|
|
}
|
|
|
|
type replayCaseResult struct {
|
|
name replayScenario
|
|
ok bool
|
|
iterations int
|
|
llmCalls int
|
|
reflectCalls int
|
|
}
|
|
|
|
func TestAgentLoopReplayBaseline(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
scenarios := []replayScenario{
|
|
replayDirectSuccess,
|
|
replayOneToolSuccess,
|
|
replayRepeatedToolCall,
|
|
replayTransientFailure,
|
|
replayPermissionBlock,
|
|
}
|
|
|
|
results := make([]replayCaseResult, 0, len(scenarios))
|
|
for _, sc := range scenarios {
|
|
sc := sc
|
|
t.Run(string(sc), func(t *testing.T) {
|
|
reg := tools.NewToolRegistry()
|
|
reg.Register(replayToolImpl{
|
|
name: "ok_tool",
|
|
run: func(ctx context.Context, args map[string]interface{}) (string, error) {
|
|
return "ok", nil
|
|
},
|
|
})
|
|
reg.Register(replayToolImpl{
|
|
name: "fail_tool_transient",
|
|
run: func(ctx context.Context, args map[string]interface{}) (string, error) {
|
|
return "", fmt.Errorf("temporary unavailable 503")
|
|
},
|
|
})
|
|
reg.Register(replayToolImpl{
|
|
name: "fail_tool_permission",
|
|
run: func(ctx context.Context, args map[string]interface{}) (string, error) {
|
|
return "", fmt.Errorf("permission denied")
|
|
},
|
|
})
|
|
|
|
provider := &replayProvider{scenario: sc}
|
|
al := &AgentLoop{
|
|
provider: provider,
|
|
providersByProxy: map[string]providers.LLMProvider{"proxy": provider},
|
|
modelsByProxy: map[string][]string{"proxy": {"test-model"}},
|
|
proxy: "proxy",
|
|
model: "test-model",
|
|
maxIterations: 6,
|
|
llmCallTimeout: 3 * time.Second,
|
|
tools: reg,
|
|
sessions: session.NewSessionManager(""),
|
|
workspace: t.TempDir(),
|
|
}
|
|
|
|
msgs := []providers.Message{
|
|
{Role: "system", Content: "you are a test agent"},
|
|
{Role: "user", Content: "complete task"},
|
|
}
|
|
|
|
out, iterations, err := al.runLLMToolLoop(context.Background(), msgs, "replay:"+string(sc), false, nil)
|
|
if err != nil {
|
|
t.Fatalf("runLLMToolLoop error: %v", err)
|
|
}
|
|
if strings.TrimSpace(out) == "" {
|
|
t.Fatalf("empty output")
|
|
}
|
|
results = append(results, replayCaseResult{
|
|
name: sc,
|
|
ok: true,
|
|
iterations: iterations,
|
|
llmCalls: provider.totalCalls,
|
|
reflectCalls: provider.reflectCalls,
|
|
})
|
|
})
|
|
}
|
|
|
|
total := len(results)
|
|
if total != len(scenarios) {
|
|
t.Fatalf("unexpected results count: %d", total)
|
|
}
|
|
success := 0
|
|
iterSum := 0
|
|
llmSum := 0
|
|
reflectSum := 0
|
|
for _, r := range results {
|
|
if r.ok {
|
|
success++
|
|
}
|
|
iterSum += r.iterations
|
|
llmSum += r.llmCalls
|
|
reflectSum += r.reflectCalls
|
|
}
|
|
successRate := float64(success) / float64(total)
|
|
avgIter := float64(iterSum) / float64(total)
|
|
avgLLM := float64(llmSum) / float64(total)
|
|
avgReflect := float64(reflectSum) / float64(total)
|
|
|
|
t.Logf("Replay baseline: success_rate=%.2f avg_iterations=%.2f avg_llm_calls=%.2f avg_reflect_calls=%.2f", successRate, avgIter, avgLLM, avgReflect)
|
|
|
|
if successRate < 1.0 {
|
|
t.Fatalf("expected all scenarios to succeed, got success_rate=%.2f", successRate)
|
|
}
|
|
if avgIter > 3.6 {
|
|
t.Fatalf("avg_iterations too high: %.2f", avgIter)
|
|
}
|
|
if avgLLM > 6.0 {
|
|
t.Fatalf("avg_llm_calls too high: %.2f", avgLLM)
|
|
}
|
|
}
|