mirror of https://github.com/YspCoder/clawgo.git
synced 2026-04-15 13:37:45 +08:00
349 lines · 11 KiB · Go
package agent
|
|
|
|
import (
|
|
"context"
|
|
"strings"
|
|
"sync"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/YspCoder/clawgo/pkg/bus"
|
|
"github.com/YspCoder/clawgo/pkg/lifecycle"
|
|
"github.com/YspCoder/clawgo/pkg/providers"
|
|
"github.com/YspCoder/clawgo/pkg/session"
|
|
toolspkg "github.com/YspCoder/clawgo/pkg/tools"
|
|
)
|
|
|
|
type pressureProvider struct {
|
|
tokens int
|
|
}
|
|
|
|
func (p *pressureProvider) Chat(ctx context.Context, messages []providers.Message, tools []providers.ToolDefinition, model string, options map[string]interface{}) (*providers.LLMResponse, error) {
|
|
return &providers.LLMResponse{Content: "Key Facts\n- compacted", FinishReason: "stop"}, nil
|
|
}
|
|
|
|
func (p *pressureProvider) GetDefaultModel() string { return "pressure-model" }
|
|
|
|
func (p *pressureProvider) CountTokens(ctx context.Context, messages []providers.Message, tools []providers.ToolDefinition, model string, options map[string]interface{}) (*providers.UsageInfo, error) {
|
|
return &providers.UsageInfo{PromptTokens: p.tokens, TotalTokens: p.tokens}, nil
|
|
}
|
|
|
|
type fallbackStreamingProvider struct {
|
|
stream func(ctx context.Context, onDelta func(string)) (*providers.LLMResponse, error)
|
|
chat func(ctx context.Context) (*providers.LLMResponse, error)
|
|
}
|
|
|
|
func (p *fallbackStreamingProvider) Chat(ctx context.Context, messages []providers.Message, tools []providers.ToolDefinition, model string, options map[string]interface{}) (*providers.LLMResponse, error) {
|
|
return p.chat(ctx)
|
|
}
|
|
|
|
func (p *fallbackStreamingProvider) ChatStream(ctx context.Context, messages []providers.Message, tools []providers.ToolDefinition, model string, options map[string]interface{}, onDelta func(string)) (*providers.LLMResponse, error) {
|
|
return p.stream(ctx, onDelta)
|
|
}
|
|
|
|
func (p *fallbackStreamingProvider) GetDefaultModel() string { return "stream-model" }
|
|
|
|
type asyncCompactionProvider struct {
|
|
mu sync.Mutex
|
|
started chan int
|
|
release chan struct{}
|
|
finished chan int
|
|
calls int
|
|
}
|
|
|
|
func (p *asyncCompactionProvider) Chat(ctx context.Context, messages []providers.Message, tools []providers.ToolDefinition, model string, options map[string]interface{}) (*providers.LLMResponse, error) {
|
|
p.mu.Lock()
|
|
p.calls++
|
|
call := p.calls
|
|
p.mu.Unlock()
|
|
if p.started != nil {
|
|
p.started <- call
|
|
}
|
|
if p.release != nil {
|
|
select {
|
|
case <-ctx.Done():
|
|
return nil, ctx.Err()
|
|
case <-p.release:
|
|
}
|
|
}
|
|
if p.finished != nil {
|
|
p.finished <- call
|
|
}
|
|
return &providers.LLMResponse{Content: "Key Facts\n- compacted", FinishReason: "stop"}, nil
|
|
}
|
|
|
|
func (p *asyncCompactionProvider) GetDefaultModel() string { return "async-model" }
|
|
|
|
func TestCompactSessionTriggeredByTokenPressure(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
sm := session.NewSessionManager(t.TempDir())
|
|
key := "cli:pressure"
|
|
for _, content := range []string{"one", "two", "three", "four", "five", "six"} {
|
|
sm.AddMessage(key, "user", content)
|
|
}
|
|
provider := &pressureProvider{tokens: 900}
|
|
loop := &AgentLoop{
|
|
provider: provider,
|
|
model: provider.GetDefaultModel(),
|
|
maxTokens: 1000,
|
|
providerNames: []string{"pressure"},
|
|
sessions: sm,
|
|
compactionEnabled: true,
|
|
compactionTrigger: 100,
|
|
compactionProtectLastN: 2,
|
|
compactionKeepRecent: 2,
|
|
compactionTargetRatio: 0.35,
|
|
compactionPressureThreshold: 0.8,
|
|
compactionMaxSummaryChars: 6000,
|
|
compactionMaxTranscriptChars: 20000,
|
|
}
|
|
|
|
applied, _, _ := loop.compactSessionIfNeeded(context.Background(), key)
|
|
if !applied {
|
|
t.Fatal("expected compaction to apply")
|
|
}
|
|
|
|
history := sm.GetPromptHistory(key)
|
|
if len(history) != 3 {
|
|
t.Fatalf("expected ratio-based keep count 3, got %d", len(history))
|
|
}
|
|
if history[0].Content != "four" || history[2].Content != "six" {
|
|
t.Fatalf("expected tail messages preserved, got %#v", history)
|
|
}
|
|
if summary := sm.GetSummary(key); summary == "" {
|
|
t.Fatal("expected compaction summary to be written")
|
|
}
|
|
}
|
|
|
|
func TestFinalizeUserMessageDoesNotWaitForCompaction(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
sm := session.NewSessionManager(t.TempDir())
|
|
key := "cli:async"
|
|
for _, content := range []string{"one", "two", "three", "four", "five", "six"} {
|
|
sm.AddMessage(key, "user", content)
|
|
}
|
|
provider := &asyncCompactionProvider{
|
|
started: make(chan int, 2),
|
|
release: make(chan struct{}),
|
|
}
|
|
loop := &AgentLoop{
|
|
provider: provider,
|
|
model: provider.GetDefaultModel(),
|
|
sessions: sm,
|
|
compactionEnabled: true,
|
|
compactionTrigger: 4,
|
|
compactionProtectLastN: 2,
|
|
compactionKeepRecent: 2,
|
|
compactionTargetRatio: 0.35,
|
|
compactionPressureThreshold: 0.1,
|
|
compactionMaxSummaryChars: 6000,
|
|
compactionMaxTranscriptChars: 20000,
|
|
compactionRunner: lifecycle.NewLoopRunner(),
|
|
compactionSignal: make(chan struct{}, 1),
|
|
compactionQueued: map[string]struct{}{},
|
|
compactionInflight: map[string]struct{}{},
|
|
compactionDirty: map[string]struct{}{},
|
|
}
|
|
t.Cleanup(loop.Stop)
|
|
|
|
start := time.Now()
|
|
loop.finalizeUserMessage(key, "en", nil, "final")
|
|
if elapsed := time.Since(start); elapsed > 150*time.Millisecond {
|
|
t.Fatalf("expected finalizeUserMessage to return quickly, took %s", elapsed)
|
|
}
|
|
select {
|
|
case <-provider.started:
|
|
case <-time.After(500 * time.Millisecond):
|
|
t.Fatal("expected async compaction to start in background")
|
|
}
|
|
close(provider.release)
|
|
deadline := time.Now().Add(2 * time.Second)
|
|
for time.Now().Before(deadline) {
|
|
if summary := sm.GetSummary(key); summary != "" {
|
|
return
|
|
}
|
|
time.Sleep(20 * time.Millisecond)
|
|
}
|
|
t.Fatal("expected async compaction summary to be written")
|
|
}
|
|
|
|
func TestCompactionWorkerRetriesDirtySessionWithoutLosingNewMessages(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
sm := session.NewSessionManager(t.TempDir())
|
|
key := "cli:dirty"
|
|
for _, content := range []string{"one", "two", "three", "four", "five", "six"} {
|
|
sm.AddMessage(key, "user", content)
|
|
}
|
|
provider := &asyncCompactionProvider{
|
|
started: make(chan int, 4),
|
|
release: make(chan struct{}, 4),
|
|
finished: make(chan int, 4),
|
|
}
|
|
loop := &AgentLoop{
|
|
provider: provider,
|
|
model: provider.GetDefaultModel(),
|
|
sessions: sm,
|
|
compactionEnabled: true,
|
|
compactionTrigger: 4,
|
|
compactionProtectLastN: 2,
|
|
compactionKeepRecent: 2,
|
|
compactionTargetRatio: 0.35,
|
|
compactionPressureThreshold: 0.1,
|
|
compactionMaxSummaryChars: 6000,
|
|
compactionMaxTranscriptChars: 20000,
|
|
compactionRunner: lifecycle.NewLoopRunner(),
|
|
compactionSignal: make(chan struct{}, 1),
|
|
compactionQueued: map[string]struct{}{},
|
|
compactionInflight: map[string]struct{}{},
|
|
compactionDirty: map[string]struct{}{},
|
|
}
|
|
t.Cleanup(loop.Stop)
|
|
|
|
loop.enqueueSessionCompaction(key)
|
|
select {
|
|
case <-provider.started:
|
|
case <-time.After(time.Second):
|
|
t.Fatal("expected first compaction run to start")
|
|
}
|
|
|
|
sm.AddMessage(key, "assistant", "seven")
|
|
loop.enqueueSessionCompaction(key)
|
|
provider.release <- struct{}{}
|
|
|
|
select {
|
|
case <-provider.finished:
|
|
case <-time.After(time.Second):
|
|
t.Fatal("expected first compaction run to finish")
|
|
}
|
|
select {
|
|
case call := <-provider.started:
|
|
if call != 2 {
|
|
t.Fatalf("expected second compaction attempt after dirty retry, got call %d", call)
|
|
}
|
|
case <-time.After(time.Second):
|
|
t.Fatal("expected dirty session to trigger a second compaction run")
|
|
}
|
|
provider.release <- struct{}{}
|
|
select {
|
|
case <-provider.finished:
|
|
case <-time.After(time.Second):
|
|
t.Fatal("expected second compaction run to finish")
|
|
}
|
|
|
|
deadline := time.Now().Add(2 * time.Second)
|
|
for time.Now().Before(deadline) {
|
|
history := sm.GetPromptHistory(key)
|
|
if len(history) > 0 && history[len(history)-1].Content == "seven" && sm.GetSummary(key) != "" {
|
|
return
|
|
}
|
|
time.Sleep(20 * time.Millisecond)
|
|
}
|
|
t.Fatal("expected retried compaction to preserve new message and summary")
|
|
}
|
|
|
|
func TestRequestStreamingLLMResponseFallsBackBeforeFirstDelta(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
provider := &fallbackStreamingProvider{
|
|
stream: func(ctx context.Context, onDelta func(string)) (*providers.LLMResponse, error) {
|
|
<-ctx.Done()
|
|
return nil, providers.NewProviderExecutionError("stream_stale", "stream stale", "stream", true, "test")
|
|
},
|
|
chat: func(ctx context.Context) (*providers.LLMResponse, error) {
|
|
return &providers.LLMResponse{Content: "fallback", FinishReason: "stop"}, nil
|
|
},
|
|
}
|
|
loop := &AgentLoop{
|
|
bus: bus.NewMessageBus(),
|
|
sessionStreamed: map[string]bool{},
|
|
}
|
|
ctx, cancel := context.WithTimeout(context.Background(), 400*time.Millisecond)
|
|
defer cancel()
|
|
|
|
resp, attempts, err := loop.requestStreamingLLMResponse(llmTurnLoopConfig{
|
|
ctx: ctx,
|
|
sessionKey: "cli:test",
|
|
toolChannel: "telegram",
|
|
toolChatID: "chat",
|
|
enableStreaming: true,
|
|
}, provider, provider, provider.GetDefaultModel(), []providers.Message{{Role: "user", Content: "hello"}}, nil, nil)
|
|
if err != nil {
|
|
t.Fatalf("expected fallback success, got %v", err)
|
|
}
|
|
if attempts != 2 {
|
|
t.Fatalf("expected streaming + fallback attempts, got %d", attempts)
|
|
}
|
|
if resp == nil || resp.Content != "fallback" {
|
|
t.Fatalf("unexpected fallback response: %#v", resp)
|
|
}
|
|
}
|
|
|
|
func TestRequestStreamingLLMResponseDoesNotFallbackAfterDelta(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
provider := &fallbackStreamingProvider{
|
|
stream: func(ctx context.Context, onDelta func(string)) (*providers.LLMResponse, error) {
|
|
onDelta("partial")
|
|
return nil, providers.NewProviderExecutionError("stream_failed", "stream failed", "stream", true, "test")
|
|
},
|
|
chat: func(ctx context.Context) (*providers.LLMResponse, error) {
|
|
return &providers.LLMResponse{Content: "fallback", FinishReason: "stop"}, nil
|
|
},
|
|
}
|
|
loop := &AgentLoop{
|
|
bus: bus.NewMessageBus(),
|
|
sessionStreamed: map[string]bool{},
|
|
}
|
|
|
|
resp, attempts, err := loop.requestStreamingLLMResponse(llmTurnLoopConfig{
|
|
ctx: context.Background(),
|
|
sessionKey: "cli:test",
|
|
toolChannel: "telegram",
|
|
toolChatID: "chat",
|
|
enableStreaming: true,
|
|
}, provider, provider, provider.GetDefaultModel(), []providers.Message{{Role: "user", Content: "hello"}}, nil, nil)
|
|
if err == nil {
|
|
t.Fatal("expected stream failure without fallback")
|
|
}
|
|
if attempts != 1 {
|
|
t.Fatalf("expected single streaming attempt, got %d", attempts)
|
|
}
|
|
if resp != nil {
|
|
t.Fatalf("expected nil response on post-delta stream failure, got %#v", resp)
|
|
}
|
|
}
|
|
|
|
func TestRunLLMTurnLoopReturnsRetryLimitError(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
provider := &sequenceProvider{
|
|
responses: []*providers.LLMResponse{{
|
|
Content: "",
|
|
ToolCalls: []providers.ToolCall{
|
|
{ID: "tool-1", Name: "system_info", Arguments: map[string]interface{}{}},
|
|
},
|
|
FinishReason: "tool_calls",
|
|
}},
|
|
}
|
|
loop := &AgentLoop{
|
|
provider: provider,
|
|
model: provider.GetDefaultModel(),
|
|
maxIterations: 1,
|
|
tools: toolspkg.NewToolRegistry(),
|
|
providerNames: []string{"sequence"},
|
|
sessionProvider: map[string]string{},
|
|
}
|
|
loop.tools.Register(toolspkg.NewSystemInfoTool())
|
|
_, err := loop.runLLMTurnLoop(llmTurnLoopConfig{
|
|
ctx: context.Background(),
|
|
sessionKey: "cli:test",
|
|
messages: []providers.Message{{Role: "user", Content: "hello"}},
|
|
})
|
|
if err == nil || !strings.Contains(err.Error(), "max tool iterations exceeded") {
|
|
t.Fatalf("expected retry limit error, got %v", err)
|
|
}
|
|
}
|