clawgo/pkg/agent/loop_replay_baseline_test.go

package agent

import (
	"context"
	"fmt"
	"strings"
	"sync"
	"testing"
	"time"

	"clawgo/pkg/providers"
	"clawgo/pkg/session"
	"clawgo/pkg/tools"
)
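
// replayScenario names one canned conversation that the replay baseline test drives
// through the agent loop.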
type replayScenario string

const (
	replayDirectSuccess    replayScenario = "direct_success"
	replayOneToolSuccess   replayScenario = "one_tool_success"
	replayRepeatedToolCall replayScenario = "repeated_tool_call"
	replayTransientFailure replayScenario = "transient_failure"
	replayPermissionBlock  replayScenario = "permission_block"
)
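
// replayProvider is a scripted stand-in for providers.LLMProvider: it inspects each
// request, returns a canned response for the configured scenario, and counts how many
// calls of each kind it received.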
type replayProvider struct {
	mu            sync.Mutex
	scenario      replayScenario
	planCalls     int
	reflectCalls  int
	finalizeCalls int
	polishCalls   int
	totalCalls    int
}
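
// Chat routes each request to a scripted response based on its shape: polish and
// reflection calls arrive without tool definitions and are recognized by their prompt
// text, the remaining no-tool case is the finalization draft, and everything else is
// treated as a planning/tool-loop call.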
func (p *replayProvider) Chat(ctx context.Context, messages []providers.Message, defs []providers.ToolDefinition, model string, options map[string]interface{}) (*providers.LLMResponse, error) {
	p.mu.Lock()
	defer p.mu.Unlock()
	p.totalCalls++

	last := ""
	if len(messages) > 0 {
		last = strings.TrimSpace(messages[len(messages)-1].Content)
	}

	// Phase-2 polish call.
	if len(defs) == 0 && len(messages) > 0 && strings.Contains(strings.ToLower(strings.TrimSpace(messages[0].Content)), "rewrite the draft answer for end users") {
		p.polishCalls++
		return &providers.LLMResponse{Content: "polished final response"}, nil
	}

	// Reflection call.
	if len(defs) == 0 && strings.Contains(last, "Classify current execution progress using JSON only.") {
		p.reflectCalls++
		switch p.scenario {
		case replayTransientFailure:
			return &providers.LLMResponse{Content: `{"decision":"continue","reason":"transient failures may recover","confidence":0.74}`}, nil
		case replayRepeatedToolCall:
			return &providers.LLMResponse{Content: `{"decision":"continue","reason":"need another attempt","confidence":0.72}`}, nil
		default:
			return &providers.LLMResponse{Content: `{"decision":"continue","reason":"default continue","confidence":0.60}`}, nil
		}
	}

	// Finalization draft call.
	if len(defs) == 0 {
		p.finalizeCalls++
		return &providers.LLMResponse{Content: "draft final response"}, nil
	}

	// Planning/tool-loop calls.
	p.planCalls++
	switch p.scenario {
	case replayDirectSuccess:
		return &providers.LLMResponse{Content: "direct completed"}, nil
	case replayOneToolSuccess:
		if p.planCalls == 1 {
			return &providers.LLMResponse{
				Content: "call one tool",
				ToolCalls: []providers.ToolCall{
					{ID: "tc-1", Name: "ok_tool", Arguments: map[string]interface{}{"x": 1}},
				},
			}, nil
		}
		return &providers.LLMResponse{Content: "task completed after one tool"}, nil
	case replayRepeatedToolCall:
		return &providers.LLMResponse{
			Content: "repeating tool",
			ToolCalls: []providers.ToolCall{
				{ID: fmt.Sprintf("tc-r-%d", p.planCalls), Name: "ok_tool", Arguments: map[string]interface{}{"same": true}},
			},
		}, nil
	case replayTransientFailure:
		return &providers.LLMResponse{
			Content: "transient fail tool",
			ToolCalls: []providers.ToolCall{
				{ID: fmt.Sprintf("tc-t-%d", p.planCalls), Name: "fail_tool_transient", Arguments: map[string]interface{}{}},
			},
		}, nil
	case replayPermissionBlock:
		return &providers.LLMResponse{
			Content: "permission fail tool",
			ToolCalls: []providers.ToolCall{
				{ID: "tc-p-1", Name: "fail_tool_permission", Arguments: map[string]interface{}{}},
			},
		}, nil
	default:
		return &providers.LLMResponse{Content: "unexpected scenario"}, nil
	}
}

func (p *replayProvider) GetDefaultModel() string { return "test-model" }
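
// replayToolImpl is a minimal tool implementation whose behavior is supplied by the
// run callback, so each scenario can register succeeding or failing tools as needed.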
type replayToolImpl struct {
	name string
	run  func(context.Context, map[string]interface{}) (string, error)
}

func (t replayToolImpl) Name() string        { return t.name }
func (t replayToolImpl) Description() string { return "replay tool" }
func (t replayToolImpl) Parameters() map[string]interface{} {
	return map[string]interface{}{
		"type":       "object",
		"properties": map[string]interface{}{},
	}
}
func (t replayToolImpl) Execute(ctx context.Context, args map[string]interface{}) (string, error) {
	return t.run(ctx, args)
}
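
// replayCaseResult captures per-scenario metrics that are aggregated into the
// baseline numbers logged and asserted at the end of the test.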
type replayCaseResult struct {
	name         replayScenario
	ok           bool
	iterations   int
	llmCalls     int
	reflectCalls int
}
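
// TestAgentLoopReplayBaseline replays five canned scenarios through runLLMToolLoop
// using the scripted provider and stub tools, then checks aggregate baseline metrics:
// every scenario must succeed, and the average iteration and LLM-call counts must stay
// under fixed thresholds.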
func TestAgentLoopReplayBaseline(t *testing.T) {
	t.Parallel()

	scenarios := []replayScenario{
		replayDirectSuccess,
		replayOneToolSuccess,
		replayRepeatedToolCall,
		replayTransientFailure,
		replayPermissionBlock,
	}
	results := make([]replayCaseResult, 0, len(scenarios))

	for _, sc := range scenarios {
		sc := sc
		t.Run(string(sc), func(t *testing.T) {
			// One succeeding tool and two failing tools (transient vs. permission error).
			reg := tools.NewToolRegistry()
			reg.Register(replayToolImpl{
				name: "ok_tool",
				run: func(ctx context.Context, args map[string]interface{}) (string, error) {
					return "ok", nil
				},
			})
			reg.Register(replayToolImpl{
				name: "fail_tool_transient",
				run: func(ctx context.Context, args map[string]interface{}) (string, error) {
					return "", fmt.Errorf("temporary unavailable 503")
				},
			})
			reg.Register(replayToolImpl{
				name: "fail_tool_permission",
				run: func(ctx context.Context, args map[string]interface{}) (string, error) {
					return "", fmt.Errorf("permission denied")
				},
			})

			provider := &replayProvider{scenario: sc}
			al := &AgentLoop{
				provider:         provider,
				providersByProxy: map[string]providers.LLMProvider{"proxy": provider},
				modelsByProxy:    map[string][]string{"proxy": {"test-model"}},
				proxy:            "proxy",
				model:            "test-model",
				maxIterations:    6,
				llmCallTimeout:   3 * time.Second,
				tools:            reg,
				sessions:         session.NewSessionManager(""),
				workspace:        t.TempDir(),
			}
			msgs := []providers.Message{
				{Role: "system", Content: "you are a test agent"},
				{Role: "user", Content: "complete task"},
			}

			out, iterations, err := al.runLLMToolLoop(context.Background(), msgs, "replay:"+string(sc), false, nil)
			if err != nil {
				t.Fatalf("runLLMToolLoop error: %v", err)
			}
			if strings.TrimSpace(out) == "" {
				t.Fatalf("empty output")
			}

			results = append(results, replayCaseResult{
				name:         sc,
				ok:           true,
				iterations:   iterations,
				llmCalls:     provider.totalCalls,
				reflectCalls: provider.reflectCalls,
			})
		})
	}

	total := len(results)
	if total != len(scenarios) {
		t.Fatalf("unexpected results count: %d", total)
	}

	// Aggregate per-scenario metrics into the baseline numbers.
	success := 0
	iterSum := 0
	llmSum := 0
	reflectSum := 0
	for _, r := range results {
		if r.ok {
			success++
		}
		iterSum += r.iterations
		llmSum += r.llmCalls
		reflectSum += r.reflectCalls
	}
	successRate := float64(success) / float64(total)
	avgIter := float64(iterSum) / float64(total)
	avgLLM := float64(llmSum) / float64(total)
	avgReflect := float64(reflectSum) / float64(total)
	t.Logf("Replay baseline: success_rate=%.2f avg_iterations=%.2f avg_llm_calls=%.2f avg_reflect_calls=%.2f", successRate, avgIter, avgLLM, avgReflect)

	// Baseline thresholds: all scenarios must succeed, with bounded average cost.
	if successRate < 1.0 {
		t.Fatalf("expected all scenarios to succeed, got success_rate=%.2f", successRate)
	}
	if avgIter > 3.6 {
		t.Fatalf("avg_iterations too high: %.2f", avgIter)
	}
	if avgLLM > 6.0 {
		t.Fatalf("avg_llm_calls too high: %.2f", avgLLM)
	}
}