mirror of
https://github.com/YspCoder/clawgo.git
synced 2026-04-13 04:27:28 +08:00
194 lines
6.0 KiB
Go
package agent
|
||
|
||
import (
|
||
"context"
|
||
"os"
|
||
"path/filepath"
|
||
"strconv"
|
||
"strings"
|
||
"sync"
|
||
"testing"
|
||
"time"
|
||
|
||
"clawgo/pkg/bus"
|
||
"clawgo/pkg/config"
|
||
"clawgo/pkg/ekg"
|
||
"clawgo/pkg/providers"
|
||
)
|
||
|
||
func TestSplitPlannedSegments_Bullets(t *testing.T) {
|
||
parts := splitPlannedSegments("- 修复 a.go\n- 补充 b.go 测试")
|
||
if len(parts) != 2 {
|
||
t.Fatalf("unexpected parts: %#v", parts)
|
||
}
|
||
}
|
||
|
||
func TestPlanSessionTasks_Semicolon(t *testing.T) {
|
||
loop := &AgentLoop{}
|
||
tasks := loop.planSessionTasks(bus.InboundMessage{Channel: "cli", Content: "修复 pkg/a.go;修复 pkg/b.go"})
|
||
if len(tasks) != 2 {
|
||
t.Fatalf("expected 2 tasks, got %#v", tasks)
|
||
}
|
||
if tasks[0].Content == tasks[1].Content {
|
||
t.Fatalf("expected distinct tasks: %#v", tasks)
|
||
}
|
||
}
|
||
|
||
func TestProcessPlannedMessage_AggregatesResults(t *testing.T) {
|
||
rp := &recordingProvider{responses: []providers.LLMResponse{
|
||
{Content: "done-a", FinishReason: "stop"},
|
||
{Content: "done-b", FinishReason: "stop"},
|
||
}}
|
||
loop := setupLoop(t, rp)
|
||
|
||
resp, err := loop.processPlannedMessage(context.Background(), bus.InboundMessage{
|
||
Channel: "cli",
|
||
SenderID: "u",
|
||
ChatID: "direct",
|
||
SessionKey: "sess-plan",
|
||
Content: "修复 pkg/a.go;补充 pkg/b.go 测试",
|
||
})
|
||
if err != nil {
|
||
t.Fatalf("processPlannedMessage error: %v", err)
|
||
}
|
||
if len(rp.calls) != 2 {
|
||
t.Fatalf("expected 2 provider calls, got %d", len(rp.calls))
|
||
}
|
||
if resp == "" {
|
||
t.Fatalf("expected aggregate response")
|
||
}
|
||
}
|
||
|
||
// probeProvider is a provider test double handed to NewAgentLoop. Its Chat
// method holds every call open for delayPerCall and records how many calls
// overlap, so tests can assert on how planned tasks are scheduled.
type probeProvider struct {
	// mu guards inFlight, maxInFlight and responseCount.
	mu sync.Mutex

	// inFlight is the number of Chat calls currently executing; maxInFlight
	// is the highest value inFlight has reached.
	inFlight, maxInFlight int

	// delayPerCall is how long each Chat call blocks. It is set when the
	// provider is constructed and read by Chat without taking mu.
	delayPerCall time.Duration

	// responseCount is the total number of Chat calls started so far.
	responseCount int
}
|
||
|
||
func (p *probeProvider) Chat(_ context.Context, _ []providers.Message, _ []providers.ToolDefinition, _ string, _ map[string]interface{}) (*providers.LLMResponse, error) {
|
||
p.mu.Lock()
|
||
p.inFlight++
|
||
if p.inFlight > p.maxInFlight {
|
||
p.maxInFlight = p.inFlight
|
||
}
|
||
p.responseCount++
|
||
p.mu.Unlock()
|
||
|
||
time.Sleep(p.delayPerCall)
|
||
|
||
p.mu.Lock()
|
||
n := p.responseCount
|
||
p.inFlight--
|
||
p.mu.Unlock()
|
||
resp := providers.LLMResponse{Content: "done-" + strconv.Itoa(n), FinishReason: "stop"}
|
||
return &resp, nil
|
||
}
|
||
|
||
// GetDefaultModel reports the fixed model name "test-model" for the probe.
func (p *probeProvider) GetDefaultModel() string {
	const model = "test-model"
	return model
}
|
||
|
||
func TestRunPlannedTasks_NonConflictingKeysCanRunInParallel(t *testing.T) {
|
||
p := &probeProvider{delayPerCall: 100 * time.Millisecond}
|
||
cfg := config.DefaultConfig()
|
||
cfg.Agents.Defaults.Workspace = filepath.Join(t.TempDir(), "workspace")
|
||
cfg.Agents.Defaults.MaxToolIterations = 2
|
||
cfg.Agents.Defaults.ContextCompaction.Enabled = false
|
||
loop := NewAgentLoop(cfg, bus.NewMessageBus(), p, nil)
|
||
|
||
_, err := loop.processPlannedMessage(context.Background(), bus.InboundMessage{
|
||
Channel: "cli",
|
||
SenderID: "u",
|
||
ChatID: "direct",
|
||
SessionKey: "sess-plan-parallel",
|
||
Content: "[resource_keys: file:pkg/a.go] 修复 a;[resource_keys: file:pkg/b.go] 修复 b",
|
||
})
|
||
if err != nil {
|
||
t.Fatalf("processPlannedMessage error: %v", err)
|
||
}
|
||
if p.maxInFlight < 2 {
|
||
t.Fatalf("expected parallel execution for non-conflicting keys, got maxInFlight=%d", p.maxInFlight)
|
||
}
|
||
}
|
||
|
||
func TestRunPlannedTasks_ConflictingKeysMutuallyExclusive(t *testing.T) {
|
||
p := &probeProvider{delayPerCall: 100 * time.Millisecond}
|
||
cfg := config.DefaultConfig()
|
||
cfg.Agents.Defaults.Workspace = filepath.Join(t.TempDir(), "workspace")
|
||
cfg.Agents.Defaults.MaxToolIterations = 2
|
||
cfg.Agents.Defaults.ContextCompaction.Enabled = false
|
||
loop := NewAgentLoop(cfg, bus.NewMessageBus(), p, nil)
|
||
|
||
_, err := loop.processPlannedMessage(context.Background(), bus.InboundMessage{
|
||
Channel: "cli",
|
||
SenderID: "u",
|
||
ChatID: "direct",
|
||
SessionKey: "sess-plan-locked",
|
||
Content: "[resource_keys: file:pkg/a.go] 修复 a;[resource_keys: file:pkg/a.go] 补测试",
|
||
})
|
||
if err != nil {
|
||
t.Fatalf("processPlannedMessage error: %v", err)
|
||
}
|
||
if p.maxInFlight != 1 {
|
||
t.Fatalf("expected mutual exclusion for conflicting keys, got maxInFlight=%d", p.maxInFlight)
|
||
}
|
||
}
|
||
|
||
func TestRunPlannedTasks_PublishesStepProgress(t *testing.T) {
|
||
rp := &recordingProvider{responses: []providers.LLMResponse{
|
||
{Content: "done-a", FinishReason: "stop"},
|
||
{Content: "done-b", FinishReason: "stop"},
|
||
}}
|
||
loop := setupLoop(t, rp)
|
||
|
||
_, err := loop.processPlannedMessage(context.Background(), bus.InboundMessage{
|
||
Channel: "cli",
|
||
SenderID: "u",
|
||
ChatID: "direct",
|
||
SessionKey: "sess-plan-progress",
|
||
Content: "修复 pkg/a.go;补充 pkg/b.go 测试",
|
||
})
|
||
if err != nil {
|
||
t.Fatalf("processPlannedMessage error: %v", err)
|
||
}
|
||
|
||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||
defer cancel()
|
||
out1, ok := loop.bus.SubscribeOutbound(ctx)
|
||
if !ok {
|
||
t.Fatalf("expected first progress outbound")
|
||
}
|
||
out2, ok := loop.bus.SubscribeOutbound(ctx)
|
||
if !ok {
|
||
t.Fatalf("expected second progress outbound")
|
||
}
|
||
all := out1.Content + "\n" + out2.Content
|
||
if !strings.Contains(all, "进度 1/2") || !strings.Contains(all, "进度 2/2") {
|
||
t.Fatalf("unexpected progress outputs:\n%s", all)
|
||
}
|
||
}
|
||
|
||
func TestFindRecentRelatedErrorEvent(t *testing.T) {
|
||
ws := filepath.Join(t.TempDir(), "workspace")
|
||
_ = os.MkdirAll(filepath.Join(ws, "memory"), 0o755)
|
||
line := `{"task_id":"t1","status":"error","log":"open /tmp/a.go failed","input_preview":"修复 pkg/a.go 的读取错误","source":"direct","channel":"cli"}`
|
||
if err := os.WriteFile(filepath.Join(ws, "memory", "task-audit.jsonl"), []byte(line+"\n"), 0o644); err != nil {
|
||
t.Fatalf("write audit: %v", err)
|
||
}
|
||
loop := &AgentLoop{workspace: ws, ekg: ekg.New(ws)}
|
||
loop.ekg.Record(ekg.Event{TaskID: "t1", Status: "error", Log: "open /tmp/a.go failed"})
|
||
loop.ekg.Record(ekg.Event{TaskID: "t1", Status: "error", Log: "open /tmp/a.go failed"})
|
||
loop.ekg.Record(ekg.Event{TaskID: "t1", Status: "error", Log: "open /tmp/a.go failed"})
|
||
|
||
ev, ok := loop.findRecentRelatedErrorEvent("请修复 pkg/a.go 的读取问题")
|
||
if !ok {
|
||
t.Fatalf("expected matched recent error event")
|
||
}
|
||
if ev.TaskID != "t1" {
|
||
t.Fatalf("unexpected task id: %s", ev.TaskID)
|
||
}
|
||
if hint := loop.ekgHintForTask(plannedTask{Content: "修复 pkg/a.go"}); hint == "" {
|
||
t.Fatalf("expected non-empty ekg hint")
|
||
}
|
||
}
|