fix context

This commit is contained in:
lpf
2026-02-18 01:01:49 +08:00
parent 90247386f5
commit 75328dcb5d
4 changed files with 230 additions and 111 deletions

View File

@@ -1680,7 +1680,7 @@ func isModelProviderSelectionError(err error) bool {
}
func shouldRetryWithFallbackModel(err error) bool {
return isQuotaOrRateLimitError(err) || isModelProviderSelectionError(err) || isGatewayTransientError(err)
return isQuotaOrRateLimitError(err) || isModelProviderSelectionError(err) || isGatewayTransientError(err) || isUpstreamAuthRoutingError(err)
}
func isGatewayTransientError(err error) bool {
@@ -1712,6 +1712,26 @@ func isGatewayTransientError(err error) bool {
return false
}
func isUpstreamAuthRoutingError(err error) bool {
if err == nil {
return false
}
msg := strings.ToLower(err.Error())
keywords := []string{
"auth_unavailable",
"no auth available",
"upstream auth unavailable",
}
for _, keyword := range keywords {
if strings.Contains(msg, keyword) {
return true
}
}
return false
}
func buildProviderToolDefs(toolDefs []map[string]interface{}) ([]providers.ToolDefinition, error) {
providerToolDefs := make([]providers.ToolDefinition, 0, len(toolDefs))
for i, td := range toolDefs {
@@ -1788,19 +1808,17 @@ func (al *AgentLoop) maybeCompactContext(ctx context.Context, sessionKey string)
}
messageCount := al.sessions.MessageCount(sessionKey)
if messageCount < cfg.TriggerMessages {
return nil
}
history := al.sessions.GetHistory(sessionKey)
if len(history) < cfg.TriggerMessages {
summary := al.sessions.GetSummary(sessionKey)
triggerByCount := messageCount >= cfg.TriggerMessages && len(history) >= cfg.TriggerMessages
triggerBySize := shouldCompactBySize(summary, history, cfg.MaxTranscriptChars)
if !triggerByCount && !triggerBySize {
return nil
}
if cfg.KeepRecentMessages >= len(history) {
return nil
}
summary := al.sessions.GetSummary(sessionKey)
compactUntil := len(history) - cfg.KeepRecentMessages
compactCtx, cancel := context.WithTimeout(ctx, 25*time.Second)
defer cancel()
@@ -1822,12 +1840,9 @@ func (al *AgentLoop) maybeCompactContext(ctx context.Context, sessionKey string)
}
logger.InfoCF("agent", "Context compacted automatically", map[string]interface{}{
"session_key": sessionKey,
"before_messages": before,
"after_messages": after,
"kept_recent": cfg.KeepRecentMessages,
"summary_chars": len(newSummary),
"trigger_messages": cfg.TriggerMessages,
"before": before,
"after": after,
"trigger_reason": compactionTriggerReason(triggerByCount, triggerBySize),
})
return nil
}
@@ -1865,28 +1880,114 @@ func formatCompactionTranscript(messages []providers.Message, maxChars int) stri
return ""
}
var sb strings.Builder
used := 0
lines := make([]string, 0, len(messages))
totalChars := 0
for _, m := range messages {
role := strings.TrimSpace(m.Role)
if role == "" {
role = "unknown"
}
line := fmt.Sprintf("[%s] %s\n", role, strings.TrimSpace(m.Content))
if len(line) > 1200 {
line = truncateString(line, 1200) + "\n"
maxLineLen := 1200
if role == "tool" {
maxLineLen = 420
}
if used+len(line) > maxChars {
remain := maxChars - used
if remain > 16 {
sb.WriteString(truncateString(line, remain))
}
if len(line) > maxLineLen {
line = truncateString(line, maxLineLen-1) + "\n"
}
lines = append(lines, line)
totalChars += len(line)
}
if totalChars <= maxChars {
return strings.TrimSpace(strings.Join(lines, ""))
}
// Keep both early context and recent context when transcript is oversized.
headBudget := maxChars / 3
if headBudget < 256 {
headBudget = maxChars / 2
}
tailBudget := maxChars - headBudget - 72
if tailBudget < 128 {
tailBudget = maxChars / 2
}
headEnd := 0
usedHead := 0
for i, line := range lines {
if usedHead+len(line) > headBudget {
break
}
sb.WriteString(line)
used += len(line)
usedHead += len(line)
headEnd = i + 1
}
return strings.TrimSpace(sb.String())
tailStart := len(lines)
usedTail := 0
for i := len(lines) - 1; i >= headEnd; i-- {
line := lines[i]
if usedTail+len(line) > tailBudget {
break
}
usedTail += len(line)
tailStart = i
}
var sb strings.Builder
for i := 0; i < headEnd; i++ {
sb.WriteString(lines[i])
}
omitted := tailStart - headEnd
if omitted > 0 {
sb.WriteString(fmt.Sprintf("...[%d messages omitted for compaction]...\n", omitted))
}
for i := tailStart; i < len(lines); i++ {
sb.WriteString(lines[i])
}
out := strings.TrimSpace(sb.String())
if len(out) > maxChars {
out = truncateString(out, maxChars)
}
return out
}
func shouldCompactBySize(summary string, history []providers.Message, maxTranscriptChars int) bool {
if maxTranscriptChars <= 0 || len(history) == 0 {
return false
}
return estimateCompactionChars(summary, history) >= maxTranscriptChars
}
func estimateCompactionChars(summary string, history []providers.Message) int {
total := len(strings.TrimSpace(summary))
for _, msg := range history {
total += len(strings.TrimSpace(msg.Role)) + len(strings.TrimSpace(msg.Content)) + 6
if msg.ToolCallID != "" {
total += len(msg.ToolCallID) + 8
}
for _, tc := range msg.ToolCalls {
total += len(tc.ID) + len(tc.Type) + len(tc.Name)
if tc.Function != nil {
total += len(tc.Function.Name) + len(tc.Function.Arguments)
}
}
}
return total
}
func compactionTriggerReason(byCount, bySize bool) string {
if byCount && bySize {
return "count+size"
}
if byCount {
return "count"
}
if bySize {
return "size"
}
return "none"
}
func parseTaskExecutionDirectives(content string) taskExecutionDirectives {
@@ -2344,25 +2445,11 @@ func (al *AgentLoop) handleSlashCommand(ctx context.Context, msg bus.InboundMess
return true, "", err
}
path := al.normalizeConfigPathForAgent(fields[2])
oldValue, oldExists := al.getMapValueByPathForAgent(cfgMap, path)
value := al.parseConfigValueForAgent(strings.Join(fields[3:], " "))
if err := al.setMapValueByPathForAgent(cfgMap, path, value); err != nil {
return true, "", err
}
modelSwitched := path == "agents.defaults.model" && (!oldExists || strings.TrimSpace(fmt.Sprintf("%v", oldValue)) != strings.TrimSpace(fmt.Sprintf("%v", value)))
compactedOnSwitch := false
if modelSwitched {
if err := al.compactContextBeforeModelSwitch(ctx, msg.SessionKey); err != nil {
logger.WarnCF("agent", "Pre-switch context compaction failed", map[string]interface{}{
"session_key": msg.SessionKey,
logger.FieldError: err.Error(),
})
} else {
compactedOnSwitch = true
}
}
al.applyRuntimeModelConfig(path, value)
data, err := json.MarshalIndent(cfgMap, "", " ")
@@ -2384,14 +2471,8 @@ func (al *AgentLoop) handleSlashCommand(ctx context.Context, msg bus.InboundMess
}
return true, "", fmt.Errorf("hot reload failed, config rolled back: %w", err)
}
if modelSwitched && compactedOnSwitch {
return true, fmt.Sprintf("Updated %s = %v\nContext compacted before model switch\nHot reload not applied: %v", path, value, err), nil
}
return true, fmt.Sprintf("Updated %s = %v\nHot reload not applied: %v", path, value, err), nil
}
if modelSwitched && compactedOnSwitch {
return true, fmt.Sprintf("Updated %s = %v\nContext compacted before model switch\nGateway hot reload signal sent", path, value), nil
}
return true, fmt.Sprintf("Updated %s = %v\nGateway hot reload signal sent", path, value), nil
default:
return true, "Usage: /config get <path> | /config set <path> <value>", nil
@@ -2541,71 +2622,6 @@ func (al *AgentLoop) triggerGatewayReloadFromAgent() (bool, error) {
return configops.TriggerGatewayReload(al.getConfigPathForCommands(), errGatewayNotRunningSlash)
}
func (al *AgentLoop) compactContextBeforeModelSwitch(ctx context.Context, sessionKey string) error {
if strings.TrimSpace(sessionKey) == "" {
return nil
}
history := al.sessions.GetHistory(sessionKey)
if len(history) <= 1 {
return nil
}
keepRecent := al.compactionCfg.KeepRecentMessages
if keepRecent <= 0 {
keepRecent = 20
}
if keepRecent >= len(history) {
keepRecent = len(history) / 2
}
if keepRecent <= 0 || keepRecent >= len(history) {
return nil
}
maxSummaryChars := al.compactionCfg.MaxSummaryChars
if maxSummaryChars <= 0 {
maxSummaryChars = 6000
}
maxTranscriptChars := al.compactionCfg.MaxTranscriptChars
if maxTranscriptChars <= 0 {
maxTranscriptChars = 24000
}
compactUntil := len(history) - keepRecent
if compactUntil <= 0 {
return nil
}
currentSummary := al.sessions.GetSummary(sessionKey)
compactCtx, cancel := context.WithTimeout(ctx, 25*time.Second)
defer cancel()
newSummary, err := al.buildCompactedSummary(compactCtx, currentSummary, history[:compactUntil], maxTranscriptChars)
if err != nil {
return err
}
newSummary = strings.TrimSpace(newSummary)
if newSummary == "" {
return nil
}
if len(newSummary) > maxSummaryChars {
newSummary = truncateString(newSummary, maxSummaryChars)
}
before, after, err := al.sessions.CompactHistory(sessionKey, newSummary, keepRecent)
if err != nil {
return err
}
logger.InfoCF("agent", "Context compacted before model switch", map[string]interface{}{
"session_key": sessionKey,
"before_messages": before,
"after_messages": after,
"kept_recent": keepRecent,
"summary_chars": len(newSummary),
"model_after": al.model,
})
return nil
}
func (al *AgentLoop) applyRuntimeModelConfig(path string, value interface{}) {
switch path {
case "agents.defaults.model":

View File

@@ -0,0 +1,60 @@
package agent
import (
"fmt"
"strings"
"testing"
"clawgo/pkg/providers"
)
func TestShouldCompactBySize(t *testing.T) {
history := []providers.Message{
{Role: "user", Content: strings.Repeat("a", 80)},
{Role: "assistant", Content: strings.Repeat("b", 80)},
}
if !shouldCompactBySize("", history, 120) {
t.Fatalf("expected size-based compaction trigger")
}
if shouldCompactBySize("", history, 10000) {
t.Fatalf("did not expect trigger for large threshold")
}
}
func TestFormatCompactionTranscript_HeadTailWhenOversized(t *testing.T) {
msgs := make([]providers.Message, 0, 30)
for i := 0; i < 30; i++ {
msgs = append(msgs, providers.Message{
Role: "user",
Content: fmt.Sprintf("msg-%02d %s", i, strings.Repeat("x", 80)),
})
}
out := formatCompactionTranscript(msgs, 700)
if out == "" {
t.Fatalf("expected non-empty transcript")
}
if !strings.Contains(out, "msg-00") {
t.Fatalf("expected head messages preserved, got: %q", out)
}
if !strings.Contains(out, "msg-29") {
t.Fatalf("expected tail messages preserved, got: %q", out)
}
if !strings.Contains(out, "messages omitted for compaction") {
t.Fatalf("expected omitted marker, got: %q", out)
}
if len(out) > 700 {
t.Fatalf("expected output <= max chars, got %d", len(out))
}
}
func TestFormatCompactionTranscript_TrimsToolPayloadMoreAggressively(t *testing.T) {
msgs := []providers.Message{
{Role: "tool", Content: strings.Repeat("z", 2000)},
}
out := formatCompactionTranscript(msgs, 2000)
if len(out) >= 1200 {
t.Fatalf("expected tool content to be trimmed aggressively, got length %d", len(out))
}
}

View File

@@ -120,6 +120,35 @@ func TestCallLLMWithModelFallback_RetriesOnGateway524(t *testing.T) {
}
}
func TestCallLLMWithModelFallback_RetriesOnAuthUnavailable500(t *testing.T) {
p := &fallbackTestProvider{
byModel: map[string]fallbackResult{
"gemini-3-flash": {err: fmt.Errorf(`API error (status 500, content-type "application/json"): {"error":{"message":"auth_unavailable: no auth available","type":"server_error","code":"internal_server_error"}}`)},
"gpt-4o-mini": {resp: &providers.LLMResponse{Content: "ok"}},
},
}
al := &AgentLoop{
provider: p,
model: "gemini-3-flash",
modelFallbacks: []string{"gpt-4o-mini"},
}
resp, err := al.callLLMWithModelFallback(context.Background(), nil, nil, nil)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if resp == nil || resp.Content != "ok" {
t.Fatalf("unexpected response: %+v", resp)
}
if len(p.called) != 2 {
t.Fatalf("expected 2 model attempts, got %d (%v)", len(p.called), p.called)
}
if p.called[0] != "gemini-3-flash" || p.called[1] != "gpt-4o-mini" {
t.Fatalf("unexpected model order: %v", p.called)
}
}
func TestCallLLMWithModelFallback_NoRetryOnNonRetryableError(t *testing.T) {
p := &fallbackTestProvider{
byModel: map[string]fallbackResult{
@@ -162,3 +191,10 @@ func TestShouldRetryWithFallbackModel_Gateway524Error(t *testing.T) {
t.Fatalf("expected 524 gateway timeout to trigger fallback retry")
}
}
func TestShouldRetryWithFallbackModel_AuthUnavailableError(t *testing.T) {
err := fmt.Errorf(`API error (status 500, content-type "application/json"): {"error":{"message":"auth_unavailable: no auth available","type":"server_error","code":"internal_server_error"}}`)
if !shouldRetryWithFallbackModel(err) {
t.Fatalf("expected auth_unavailable error to trigger fallback retry")
}
}