mirror of
https://github.com/YspCoder/clawgo.git
synced 2026-04-15 00:27:29 +08:00
fix context
This commit is contained in:
@@ -1680,7 +1680,7 @@ func isModelProviderSelectionError(err error) bool {
|
||||
}
|
||||
|
||||
func shouldRetryWithFallbackModel(err error) bool {
|
||||
return isQuotaOrRateLimitError(err) || isModelProviderSelectionError(err) || isGatewayTransientError(err)
|
||||
return isQuotaOrRateLimitError(err) || isModelProviderSelectionError(err) || isGatewayTransientError(err) || isUpstreamAuthRoutingError(err)
|
||||
}
|
||||
|
||||
func isGatewayTransientError(err error) bool {
|
||||
@@ -1712,6 +1712,26 @@ func isGatewayTransientError(err error) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// isUpstreamAuthRoutingError reports whether err's message contains one of
// the known "auth unavailable" markers. The comparison is a case-insensitive
// substring match against err.Error(); a nil error never matches.
func isUpstreamAuthRoutingError(err error) bool {
	if err == nil {
		return false
	}

	// Lower-case once so every marker below can be written in lower case.
	msg := strings.ToLower(err.Error())
	for _, keyword := range []string{
		"auth_unavailable",
		"no auth available",
		"upstream auth unavailable",
	} {
		if strings.Contains(msg, keyword) {
			return true
		}
	}
	return false
}
|
||||
|
||||
func buildProviderToolDefs(toolDefs []map[string]interface{}) ([]providers.ToolDefinition, error) {
|
||||
providerToolDefs := make([]providers.ToolDefinition, 0, len(toolDefs))
|
||||
for i, td := range toolDefs {
|
||||
@@ -1788,19 +1808,17 @@ func (al *AgentLoop) maybeCompactContext(ctx context.Context, sessionKey string)
|
||||
}
|
||||
|
||||
messageCount := al.sessions.MessageCount(sessionKey)
|
||||
if messageCount < cfg.TriggerMessages {
|
||||
return nil
|
||||
}
|
||||
|
||||
history := al.sessions.GetHistory(sessionKey)
|
||||
if len(history) < cfg.TriggerMessages {
|
||||
summary := al.sessions.GetSummary(sessionKey)
|
||||
triggerByCount := messageCount >= cfg.TriggerMessages && len(history) >= cfg.TriggerMessages
|
||||
triggerBySize := shouldCompactBySize(summary, history, cfg.MaxTranscriptChars)
|
||||
if !triggerByCount && !triggerBySize {
|
||||
return nil
|
||||
}
|
||||
if cfg.KeepRecentMessages >= len(history) {
|
||||
return nil
|
||||
}
|
||||
|
||||
summary := al.sessions.GetSummary(sessionKey)
|
||||
compactUntil := len(history) - cfg.KeepRecentMessages
|
||||
compactCtx, cancel := context.WithTimeout(ctx, 25*time.Second)
|
||||
defer cancel()
|
||||
@@ -1822,12 +1840,9 @@ func (al *AgentLoop) maybeCompactContext(ctx context.Context, sessionKey string)
|
||||
}
|
||||
|
||||
logger.InfoCF("agent", "Context compacted automatically", map[string]interface{}{
|
||||
"session_key": sessionKey,
|
||||
"before_messages": before,
|
||||
"after_messages": after,
|
||||
"kept_recent": cfg.KeepRecentMessages,
|
||||
"summary_chars": len(newSummary),
|
||||
"trigger_messages": cfg.TriggerMessages,
|
||||
"before": before,
|
||||
"after": after,
|
||||
"trigger_reason": compactionTriggerReason(triggerByCount, triggerBySize),
|
||||
})
|
||||
return nil
|
||||
}
|
||||
@@ -1865,28 +1880,114 @@ func formatCompactionTranscript(messages []providers.Message, maxChars int) stri
|
||||
return ""
|
||||
}
|
||||
|
||||
var sb strings.Builder
|
||||
used := 0
|
||||
lines := make([]string, 0, len(messages))
|
||||
totalChars := 0
|
||||
for _, m := range messages {
|
||||
role := strings.TrimSpace(m.Role)
|
||||
if role == "" {
|
||||
role = "unknown"
|
||||
}
|
||||
line := fmt.Sprintf("[%s] %s\n", role, strings.TrimSpace(m.Content))
|
||||
if len(line) > 1200 {
|
||||
line = truncateString(line, 1200) + "\n"
|
||||
maxLineLen := 1200
|
||||
if role == "tool" {
|
||||
maxLineLen = 420
|
||||
}
|
||||
if used+len(line) > maxChars {
|
||||
remain := maxChars - used
|
||||
if remain > 16 {
|
||||
sb.WriteString(truncateString(line, remain))
|
||||
}
|
||||
if len(line) > maxLineLen {
|
||||
line = truncateString(line, maxLineLen-1) + "\n"
|
||||
}
|
||||
lines = append(lines, line)
|
||||
totalChars += len(line)
|
||||
}
|
||||
|
||||
if totalChars <= maxChars {
|
||||
return strings.TrimSpace(strings.Join(lines, ""))
|
||||
}
|
||||
|
||||
// Keep both early context and recent context when transcript is oversized.
|
||||
headBudget := maxChars / 3
|
||||
if headBudget < 256 {
|
||||
headBudget = maxChars / 2
|
||||
}
|
||||
tailBudget := maxChars - headBudget - 72
|
||||
if tailBudget < 128 {
|
||||
tailBudget = maxChars / 2
|
||||
}
|
||||
|
||||
headEnd := 0
|
||||
usedHead := 0
|
||||
for i, line := range lines {
|
||||
if usedHead+len(line) > headBudget {
|
||||
break
|
||||
}
|
||||
sb.WriteString(line)
|
||||
used += len(line)
|
||||
usedHead += len(line)
|
||||
headEnd = i + 1
|
||||
}
|
||||
return strings.TrimSpace(sb.String())
|
||||
|
||||
tailStart := len(lines)
|
||||
usedTail := 0
|
||||
for i := len(lines) - 1; i >= headEnd; i-- {
|
||||
line := lines[i]
|
||||
if usedTail+len(line) > tailBudget {
|
||||
break
|
||||
}
|
||||
usedTail += len(line)
|
||||
tailStart = i
|
||||
}
|
||||
|
||||
var sb strings.Builder
|
||||
for i := 0; i < headEnd; i++ {
|
||||
sb.WriteString(lines[i])
|
||||
}
|
||||
omitted := tailStart - headEnd
|
||||
if omitted > 0 {
|
||||
sb.WriteString(fmt.Sprintf("...[%d messages omitted for compaction]...\n", omitted))
|
||||
}
|
||||
for i := tailStart; i < len(lines); i++ {
|
||||
sb.WriteString(lines[i])
|
||||
}
|
||||
|
||||
out := strings.TrimSpace(sb.String())
|
||||
if len(out) > maxChars {
|
||||
out = truncateString(out, maxChars)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func shouldCompactBySize(summary string, history []providers.Message, maxTranscriptChars int) bool {
|
||||
if maxTranscriptChars <= 0 || len(history) == 0 {
|
||||
return false
|
||||
}
|
||||
return estimateCompactionChars(summary, history) >= maxTranscriptChars
|
||||
}
|
||||
|
||||
func estimateCompactionChars(summary string, history []providers.Message) int {
|
||||
total := len(strings.TrimSpace(summary))
|
||||
for _, msg := range history {
|
||||
total += len(strings.TrimSpace(msg.Role)) + len(strings.TrimSpace(msg.Content)) + 6
|
||||
if msg.ToolCallID != "" {
|
||||
total += len(msg.ToolCallID) + 8
|
||||
}
|
||||
for _, tc := range msg.ToolCalls {
|
||||
total += len(tc.ID) + len(tc.Type) + len(tc.Name)
|
||||
if tc.Function != nil {
|
||||
total += len(tc.Function.Name) + len(tc.Function.Arguments)
|
||||
}
|
||||
}
|
||||
}
|
||||
return total
|
||||
}
|
||||
|
||||
// compactionTriggerReason maps the two compaction trigger flags to a label:
// "count+size" when both are set, "count" or "size" when only one is set, and
// "none" when neither is.
func compactionTriggerReason(byCount, bySize bool) string {
	switch {
	case byCount && bySize:
		return "count+size"
	case byCount:
		return "count"
	case bySize:
		return "size"
	default:
		return "none"
	}
}
|
||||
|
||||
func parseTaskExecutionDirectives(content string) taskExecutionDirectives {
|
||||
@@ -2344,25 +2445,11 @@ func (al *AgentLoop) handleSlashCommand(ctx context.Context, msg bus.InboundMess
|
||||
return true, "", err
|
||||
}
|
||||
path := al.normalizeConfigPathForAgent(fields[2])
|
||||
oldValue, oldExists := al.getMapValueByPathForAgent(cfgMap, path)
|
||||
value := al.parseConfigValueForAgent(strings.Join(fields[3:], " "))
|
||||
if err := al.setMapValueByPathForAgent(cfgMap, path, value); err != nil {
|
||||
return true, "", err
|
||||
}
|
||||
|
||||
modelSwitched := path == "agents.defaults.model" && (!oldExists || strings.TrimSpace(fmt.Sprintf("%v", oldValue)) != strings.TrimSpace(fmt.Sprintf("%v", value)))
|
||||
compactedOnSwitch := false
|
||||
if modelSwitched {
|
||||
if err := al.compactContextBeforeModelSwitch(ctx, msg.SessionKey); err != nil {
|
||||
logger.WarnCF("agent", "Pre-switch context compaction failed", map[string]interface{}{
|
||||
"session_key": msg.SessionKey,
|
||||
logger.FieldError: err.Error(),
|
||||
})
|
||||
} else {
|
||||
compactedOnSwitch = true
|
||||
}
|
||||
}
|
||||
|
||||
al.applyRuntimeModelConfig(path, value)
|
||||
|
||||
data, err := json.MarshalIndent(cfgMap, "", " ")
|
||||
@@ -2384,14 +2471,8 @@ func (al *AgentLoop) handleSlashCommand(ctx context.Context, msg bus.InboundMess
|
||||
}
|
||||
return true, "", fmt.Errorf("hot reload failed, config rolled back: %w", err)
|
||||
}
|
||||
if modelSwitched && compactedOnSwitch {
|
||||
return true, fmt.Sprintf("Updated %s = %v\nContext compacted before model switch\nHot reload not applied: %v", path, value, err), nil
|
||||
}
|
||||
return true, fmt.Sprintf("Updated %s = %v\nHot reload not applied: %v", path, value, err), nil
|
||||
}
|
||||
if modelSwitched && compactedOnSwitch {
|
||||
return true, fmt.Sprintf("Updated %s = %v\nContext compacted before model switch\nGateway hot reload signal sent", path, value), nil
|
||||
}
|
||||
return true, fmt.Sprintf("Updated %s = %v\nGateway hot reload signal sent", path, value), nil
|
||||
default:
|
||||
return true, "Usage: /config get <path> | /config set <path> <value>", nil
|
||||
@@ -2541,71 +2622,6 @@ func (al *AgentLoop) triggerGatewayReloadFromAgent() (bool, error) {
|
||||
return configops.TriggerGatewayReload(al.getConfigPathForCommands(), errGatewayNotRunningSlash)
|
||||
}
|
||||
|
||||
func (al *AgentLoop) compactContextBeforeModelSwitch(ctx context.Context, sessionKey string) error {
|
||||
if strings.TrimSpace(sessionKey) == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
history := al.sessions.GetHistory(sessionKey)
|
||||
if len(history) <= 1 {
|
||||
return nil
|
||||
}
|
||||
|
||||
keepRecent := al.compactionCfg.KeepRecentMessages
|
||||
if keepRecent <= 0 {
|
||||
keepRecent = 20
|
||||
}
|
||||
if keepRecent >= len(history) {
|
||||
keepRecent = len(history) / 2
|
||||
}
|
||||
if keepRecent <= 0 || keepRecent >= len(history) {
|
||||
return nil
|
||||
}
|
||||
|
||||
maxSummaryChars := al.compactionCfg.MaxSummaryChars
|
||||
if maxSummaryChars <= 0 {
|
||||
maxSummaryChars = 6000
|
||||
}
|
||||
maxTranscriptChars := al.compactionCfg.MaxTranscriptChars
|
||||
if maxTranscriptChars <= 0 {
|
||||
maxTranscriptChars = 24000
|
||||
}
|
||||
|
||||
compactUntil := len(history) - keepRecent
|
||||
if compactUntil <= 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
currentSummary := al.sessions.GetSummary(sessionKey)
|
||||
compactCtx, cancel := context.WithTimeout(ctx, 25*time.Second)
|
||||
defer cancel()
|
||||
newSummary, err := al.buildCompactedSummary(compactCtx, currentSummary, history[:compactUntil], maxTranscriptChars)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
newSummary = strings.TrimSpace(newSummary)
|
||||
if newSummary == "" {
|
||||
return nil
|
||||
}
|
||||
if len(newSummary) > maxSummaryChars {
|
||||
newSummary = truncateString(newSummary, maxSummaryChars)
|
||||
}
|
||||
|
||||
before, after, err := al.sessions.CompactHistory(sessionKey, newSummary, keepRecent)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
logger.InfoCF("agent", "Context compacted before model switch", map[string]interface{}{
|
||||
"session_key": sessionKey,
|
||||
"before_messages": before,
|
||||
"after_messages": after,
|
||||
"kept_recent": keepRecent,
|
||||
"summary_chars": len(newSummary),
|
||||
"model_after": al.model,
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (al *AgentLoop) applyRuntimeModelConfig(path string, value interface{}) {
|
||||
switch path {
|
||||
case "agents.defaults.model":
|
||||
|
||||
60
pkg/agent/loop_compaction_test.go
Normal file
60
pkg/agent/loop_compaction_test.go
Normal file
@@ -0,0 +1,60 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"clawgo/pkg/providers"
|
||||
)
|
||||
|
||||
func TestShouldCompactBySize(t *testing.T) {
|
||||
history := []providers.Message{
|
||||
{Role: "user", Content: strings.Repeat("a", 80)},
|
||||
{Role: "assistant", Content: strings.Repeat("b", 80)},
|
||||
}
|
||||
|
||||
if !shouldCompactBySize("", history, 120) {
|
||||
t.Fatalf("expected size-based compaction trigger")
|
||||
}
|
||||
if shouldCompactBySize("", history, 10000) {
|
||||
t.Fatalf("did not expect trigger for large threshold")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatCompactionTranscript_HeadTailWhenOversized(t *testing.T) {
|
||||
msgs := make([]providers.Message, 0, 30)
|
||||
for i := 0; i < 30; i++ {
|
||||
msgs = append(msgs, providers.Message{
|
||||
Role: "user",
|
||||
Content: fmt.Sprintf("msg-%02d %s", i, strings.Repeat("x", 80)),
|
||||
})
|
||||
}
|
||||
|
||||
out := formatCompactionTranscript(msgs, 700)
|
||||
if out == "" {
|
||||
t.Fatalf("expected non-empty transcript")
|
||||
}
|
||||
if !strings.Contains(out, "msg-00") {
|
||||
t.Fatalf("expected head messages preserved, got: %q", out)
|
||||
}
|
||||
if !strings.Contains(out, "msg-29") {
|
||||
t.Fatalf("expected tail messages preserved, got: %q", out)
|
||||
}
|
||||
if !strings.Contains(out, "messages omitted for compaction") {
|
||||
t.Fatalf("expected omitted marker, got: %q", out)
|
||||
}
|
||||
if len(out) > 700 {
|
||||
t.Fatalf("expected output <= max chars, got %d", len(out))
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatCompactionTranscript_TrimsToolPayloadMoreAggressively(t *testing.T) {
|
||||
msgs := []providers.Message{
|
||||
{Role: "tool", Content: strings.Repeat("z", 2000)},
|
||||
}
|
||||
out := formatCompactionTranscript(msgs, 2000)
|
||||
if len(out) >= 1200 {
|
||||
t.Fatalf("expected tool content to be trimmed aggressively, got length %d", len(out))
|
||||
}
|
||||
}
|
||||
@@ -120,6 +120,35 @@ func TestCallLLMWithModelFallback_RetriesOnGateway524(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestCallLLMWithModelFallback_RetriesOnAuthUnavailable500(t *testing.T) {
|
||||
p := &fallbackTestProvider{
|
||||
byModel: map[string]fallbackResult{
|
||||
"gemini-3-flash": {err: fmt.Errorf(`API error (status 500, content-type "application/json"): {"error":{"message":"auth_unavailable: no auth available","type":"server_error","code":"internal_server_error"}}`)},
|
||||
"gpt-4o-mini": {resp: &providers.LLMResponse{Content: "ok"}},
|
||||
},
|
||||
}
|
||||
|
||||
al := &AgentLoop{
|
||||
provider: p,
|
||||
model: "gemini-3-flash",
|
||||
modelFallbacks: []string{"gpt-4o-mini"},
|
||||
}
|
||||
|
||||
resp, err := al.callLLMWithModelFallback(context.Background(), nil, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if resp == nil || resp.Content != "ok" {
|
||||
t.Fatalf("unexpected response: %+v", resp)
|
||||
}
|
||||
if len(p.called) != 2 {
|
||||
t.Fatalf("expected 2 model attempts, got %d (%v)", len(p.called), p.called)
|
||||
}
|
||||
if p.called[0] != "gemini-3-flash" || p.called[1] != "gpt-4o-mini" {
|
||||
t.Fatalf("unexpected model order: %v", p.called)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCallLLMWithModelFallback_NoRetryOnNonRetryableError(t *testing.T) {
|
||||
p := &fallbackTestProvider{
|
||||
byModel: map[string]fallbackResult{
|
||||
@@ -162,3 +191,10 @@ func TestShouldRetryWithFallbackModel_Gateway524Error(t *testing.T) {
|
||||
t.Fatalf("expected 524 gateway timeout to trigger fallback retry")
|
||||
}
|
||||
}
|
||||
|
||||
func TestShouldRetryWithFallbackModel_AuthUnavailableError(t *testing.T) {
|
||||
err := fmt.Errorf(`API error (status 500, content-type "application/json"): {"error":{"message":"auth_unavailable: no auth available","type":"server_error","code":"internal_server_error"}}`)
|
||||
if !shouldRetryWithFallbackModel(err) {
|
||||
t.Fatalf("expected auth_unavailable error to trigger fallback retry")
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user