diff --git a/config.example.json b/config.example.json index 3da2777..0cd38a1 100644 --- a/config.example.json +++ b/config.example.json @@ -6,7 +6,14 @@ "model_fallbacks": ["gpt-4o-mini", "deepseek-chat"], "max_tokens": 8192, "temperature": 0.7, - "max_tool_iterations": 20 + "max_tool_iterations": 20, + "context_compaction": { + "enabled": true, + "trigger_messages": 60, + "keep_recent_messages": 20, + "max_summary_chars": 6000, + "max_transcript_chars": 20000 + } } }, "channels": { diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go index 76207e0..25c4813 100644 --- a/pkg/agent/loop.go +++ b/pkg/agent/loop.go @@ -1680,7 +1680,7 @@ func isModelProviderSelectionError(err error) bool { } func shouldRetryWithFallbackModel(err error) bool { - return isQuotaOrRateLimitError(err) || isModelProviderSelectionError(err) || isGatewayTransientError(err) + return isQuotaOrRateLimitError(err) || isModelProviderSelectionError(err) || isGatewayTransientError(err) || isUpstreamAuthRoutingError(err) } func isGatewayTransientError(err error) bool { @@ -1712,6 +1712,26 @@ func isGatewayTransientError(err error) bool { return false } +func isUpstreamAuthRoutingError(err error) bool { + if err == nil { + return false + } + + msg := strings.ToLower(err.Error()) + keywords := []string{ + "auth_unavailable", + "no auth available", + "upstream auth unavailable", + } + + for _, keyword := range keywords { + if strings.Contains(msg, keyword) { + return true + } + } + return false +} + func buildProviderToolDefs(toolDefs []map[string]interface{}) ([]providers.ToolDefinition, error) { providerToolDefs := make([]providers.ToolDefinition, 0, len(toolDefs)) for i, td := range toolDefs { @@ -1788,19 +1808,17 @@ func (al *AgentLoop) maybeCompactContext(ctx context.Context, sessionKey string) } messageCount := al.sessions.MessageCount(sessionKey) - if messageCount < cfg.TriggerMessages { - return nil - } - history := al.sessions.GetHistory(sessionKey) - if len(history) < cfg.TriggerMessages { + summary := al.sessions.GetSummary(sessionKey) + triggerByCount := messageCount >= cfg.TriggerMessages && len(history) >= cfg.TriggerMessages + triggerBySize := shouldCompactBySize(summary, history, cfg.MaxTranscriptChars) + if !triggerByCount && !triggerBySize { return nil } if cfg.KeepRecentMessages >= len(history) { return nil } - summary := al.sessions.GetSummary(sessionKey) compactUntil := len(history) - cfg.KeepRecentMessages compactCtx, cancel := context.WithTimeout(ctx, 25*time.Second) defer cancel() @@ -1822,12 +1840,9 @@ func (al *AgentLoop) maybeCompactContext(ctx context.Context, sessionKey string) } logger.InfoCF("agent", "Context compacted automatically", map[string]interface{}{ - "session_key": sessionKey, - "before_messages": before, - "after_messages": after, - "kept_recent": cfg.KeepRecentMessages, - "summary_chars": len(newSummary), - "trigger_messages": cfg.TriggerMessages, + "before": before, + "after": after, + "trigger_reason": compactionTriggerReason(triggerByCount, triggerBySize), }) return nil } @@ -1865,28 +1880,114 @@ func formatCompactionTranscript(messages []providers.Message, maxChars int) stri return "" } - var sb strings.Builder - used := 0 + lines := make([]string, 0, len(messages)) + totalChars := 0 for _, m := range messages { role := strings.TrimSpace(m.Role) if role == "" { role = "unknown" } line := fmt.Sprintf("[%s] %s\n", role, strings.TrimSpace(m.Content)) - if len(line) > 1200 { - line = truncateString(line, 1200) + "\n" + maxLineLen := 1200 + if role == "tool" { + maxLineLen = 420 } - if used+len(line) > maxChars { - remain := maxChars - used - if remain > 16 { - sb.WriteString(truncateString(line, remain)) - } + if len(line) > maxLineLen { + line = truncateString(line, maxLineLen-1) + "\n" + } + lines = append(lines, line) + totalChars += len(line) + } + + if totalChars <= maxChars { + return strings.TrimSpace(strings.Join(lines, "")) + } + + // Keep both early context and recent context when transcript is oversized. + headBudget := maxChars / 3 + if headBudget < 256 { + headBudget = maxChars / 2 + } + tailBudget := maxChars - headBudget - 72 + if tailBudget < 128 { + tailBudget = maxChars / 2 + } + + headEnd := 0 + usedHead := 0 + for i, line := range lines { + if usedHead+len(line) > headBudget { break } - sb.WriteString(line) - used += len(line) + usedHead += len(line) + headEnd = i + 1 } - return strings.TrimSpace(sb.String()) + + tailStart := len(lines) + usedTail := 0 + for i := len(lines) - 1; i >= headEnd; i-- { + line := lines[i] + if usedTail+len(line) > tailBudget { + break + } + usedTail += len(line) + tailStart = i + } + + var sb strings.Builder + for i := 0; i < headEnd; i++ { + sb.WriteString(lines[i]) + } + omitted := tailStart - headEnd + if omitted > 0 { + sb.WriteString(fmt.Sprintf("...[%d messages omitted for compaction]...\n", omitted)) + } + for i := tailStart; i < len(lines); i++ { + sb.WriteString(lines[i]) + } + + out := strings.TrimSpace(sb.String()) + if len(out) > maxChars { + out = truncateString(out, maxChars) + } + return out +} + +func shouldCompactBySize(summary string, history []providers.Message, maxTranscriptChars int) bool { + if maxTranscriptChars <= 0 || len(history) == 0 { + return false + } + return estimateCompactionChars(summary, history) >= maxTranscriptChars +} + +func estimateCompactionChars(summary string, history []providers.Message) int { + total := len(strings.TrimSpace(summary)) + for _, msg := range history { + total += len(strings.TrimSpace(msg.Role)) + len(strings.TrimSpace(msg.Content)) + 6 + if msg.ToolCallID != "" { + total += len(msg.ToolCallID) + 8 + } + for _, tc := range msg.ToolCalls { + total += len(tc.ID) + len(tc.Type) + len(tc.Name) + if tc.Function != nil { + total += len(tc.Function.Name) + len(tc.Function.Arguments) + } + } + } + return total +} + +func compactionTriggerReason(byCount, bySize bool) string { + if byCount && bySize { + return "count+size" + } + if byCount { + return "count" + } + if bySize { + return "size" + } + return "none" } func parseTaskExecutionDirectives(content string) taskExecutionDirectives { @@ -2344,25 +2445,11 @@ func (al *AgentLoop) handleSlashCommand(ctx context.Context, msg bus.InboundMess return true, "", err } path := al.normalizeConfigPathForAgent(fields[2]) - oldValue, oldExists := al.getMapValueByPathForAgent(cfgMap, path) value := al.parseConfigValueForAgent(strings.Join(fields[3:], " ")) if err := al.setMapValueByPathForAgent(cfgMap, path, value); err != nil { return true, "", err } - modelSwitched := path == "agents.defaults.model" && (!oldExists || strings.TrimSpace(fmt.Sprintf("%v", oldValue)) != strings.TrimSpace(fmt.Sprintf("%v", value))) - compactedOnSwitch := false - if modelSwitched { - if err := al.compactContextBeforeModelSwitch(ctx, msg.SessionKey); err != nil { - logger.WarnCF("agent", "Pre-switch context compaction failed", map[string]interface{}{ - "session_key": msg.SessionKey, - logger.FieldError: err.Error(), - }) - } else { - compactedOnSwitch = true - } - } - al.applyRuntimeModelConfig(path, value) data, err := json.MarshalIndent(cfgMap, "", " ") @@ -2384,14 +2471,8 @@ func (al *AgentLoop) handleSlashCommand(ctx context.Context, msg bus.InboundMess } return true, "", fmt.Errorf("hot reload failed, config rolled back: %w", err) } - if modelSwitched && compactedOnSwitch { - return true, fmt.Sprintf("Updated %s = %v\nContext compacted before model switch\nHot reload not applied: %v", path, value, err), nil - } return true, fmt.Sprintf("Updated %s = %v\nHot reload not applied: %v", path, value, err), nil } - if modelSwitched && compactedOnSwitch { - return true, fmt.Sprintf("Updated %s = %v\nContext compacted before model switch\nGateway hot reload signal sent", path, value), nil - } return true, fmt.Sprintf("Updated %s = %v\nGateway hot reload signal sent", path, value), nil default: return true, "Usage: /config get | /config set ", nil @@ -2541,71 +2622,6 @@ func (al *AgentLoop) triggerGatewayReloadFromAgent() (bool, error) { return configops.TriggerGatewayReload(al.getConfigPathForCommands(), errGatewayNotRunningSlash) } -func (al *AgentLoop) compactContextBeforeModelSwitch(ctx context.Context, sessionKey string) error { - if strings.TrimSpace(sessionKey) == "" { - return nil - } - - history := al.sessions.GetHistory(sessionKey) - if len(history) <= 1 { - return nil - } - - keepRecent := al.compactionCfg.KeepRecentMessages - if keepRecent <= 0 { - keepRecent = 20 - } - if keepRecent >= len(history) { - keepRecent = len(history) / 2 - } - if keepRecent <= 0 || keepRecent >= len(history) { - return nil - } - - maxSummaryChars := al.compactionCfg.MaxSummaryChars - if maxSummaryChars <= 0 { - maxSummaryChars = 6000 - } - maxTranscriptChars := al.compactionCfg.MaxTranscriptChars - if maxTranscriptChars <= 0 { - maxTranscriptChars = 24000 - } - - compactUntil := len(history) - keepRecent - if compactUntil <= 0 { - return nil - } - - currentSummary := al.sessions.GetSummary(sessionKey) - compactCtx, cancel := context.WithTimeout(ctx, 25*time.Second) - defer cancel() - newSummary, err := al.buildCompactedSummary(compactCtx, currentSummary, history[:compactUntil], maxTranscriptChars) - if err != nil { - return err - } - newSummary = strings.TrimSpace(newSummary) - if newSummary == "" { - return nil - } - if len(newSummary) > maxSummaryChars { - newSummary = truncateString(newSummary, maxSummaryChars) - } - - before, after, err := al.sessions.CompactHistory(sessionKey, newSummary, keepRecent) - if err != nil { - return err - } - logger.InfoCF("agent", "Context compacted before model switch", map[string]interface{}{ - "session_key": sessionKey, - "before_messages": before, - "after_messages": after, - "kept_recent": keepRecent, - "summary_chars": len(newSummary), - "model_after": al.model, - }) - return nil -} - func (al *AgentLoop) applyRuntimeModelConfig(path string, value interface{}) { switch path { case "agents.defaults.model": diff --git a/pkg/agent/loop_compaction_test.go b/pkg/agent/loop_compaction_test.go new file mode 100644 index 0000000..f6a14a8 --- /dev/null +++ b/pkg/agent/loop_compaction_test.go @@ -0,0 +1,60 @@ +package agent + +import ( + "fmt" + "strings" + "testing" + + "clawgo/pkg/providers" +) + +func TestShouldCompactBySize(t *testing.T) { + history := []providers.Message{ + {Role: "user", Content: strings.Repeat("a", 80)}, + {Role: "assistant", Content: strings.Repeat("b", 80)}, + } + + if !shouldCompactBySize("", history, 120) { + t.Fatalf("expected size-based compaction trigger") + } + if shouldCompactBySize("", history, 10000) { + t.Fatalf("did not expect trigger for large threshold") + } +} + +func TestFormatCompactionTranscript_HeadTailWhenOversized(t *testing.T) { + msgs := make([]providers.Message, 0, 30) + for i := 0; i < 30; i++ { + msgs = append(msgs, providers.Message{ + Role: "user", + Content: fmt.Sprintf("msg-%02d %s", i, strings.Repeat("x", 80)), + }) + } + + out := formatCompactionTranscript(msgs, 700) + if out == "" { + t.Fatalf("expected non-empty transcript") + } + if !strings.Contains(out, "msg-00") { + t.Fatalf("expected head messages preserved, got: %q", out) + } + if !strings.Contains(out, "msg-29") { + t.Fatalf("expected tail messages preserved, got: %q", out) + } + if !strings.Contains(out, "messages omitted for compaction") { + t.Fatalf("expected omitted marker, got: %q", out) + } + if len(out) > 700 { + t.Fatalf("expected output <= max chars, got %d", len(out)) + } +} + +func TestFormatCompactionTranscript_TrimsToolPayloadMoreAggressively(t *testing.T) { + msgs := []providers.Message{ + {Role: "tool", Content: strings.Repeat("z", 2000)}, + } + out := formatCompactionTranscript(msgs, 2000) + if len(out) >= 1200 { + t.Fatalf("expected tool content to be trimmed aggressively, got length %d", len(out)) + } +} diff --git a/pkg/agent/loop_fallback_test.go b/pkg/agent/loop_fallback_test.go index a625bea..4fabe92 100644 --- a/pkg/agent/loop_fallback_test.go +++ b/pkg/agent/loop_fallback_test.go @@ -120,6 +120,35 @@ func TestCallLLMWithModelFallback_RetriesOnGateway524(t *testing.T) { } } +func TestCallLLMWithModelFallback_RetriesOnAuthUnavailable500(t *testing.T) { + p := &fallbackTestProvider{ + byModel: map[string]fallbackResult{ + "gemini-3-flash": {err: fmt.Errorf(`API error (status 500, content-type "application/json"): {"error":{"message":"auth_unavailable: no auth available","type":"server_error","code":"internal_server_error"}}`)}, + "gpt-4o-mini": {resp: &providers.LLMResponse{Content: "ok"}}, + }, + } + + al := &AgentLoop{ + provider: p, + model: "gemini-3-flash", + modelFallbacks: []string{"gpt-4o-mini"}, + } + + resp, err := al.callLLMWithModelFallback(context.Background(), nil, nil, nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if resp == nil || resp.Content != "ok" { + t.Fatalf("unexpected response: %+v", resp) + } + if len(p.called) != 2 { + t.Fatalf("expected 2 model attempts, got %d (%v)", len(p.called), p.called) + } + if p.called[0] != "gemini-3-flash" || p.called[1] != "gpt-4o-mini" { + t.Fatalf("unexpected model order: %v", p.called) + } +} + func TestCallLLMWithModelFallback_NoRetryOnNonRetryableError(t *testing.T) { p := &fallbackTestProvider{ byModel: map[string]fallbackResult{ @@ -162,3 +191,10 @@ func TestShouldRetryWithFallbackModel_Gateway524Error(t *testing.T) { t.Fatalf("expected 524 gateway timeout to trigger fallback retry") } } + +func TestShouldRetryWithFallbackModel_AuthUnavailableError(t *testing.T) { + err := fmt.Errorf(`API error (status 500, content-type "application/json"): {"error":{"message":"auth_unavailable: no auth available","type":"server_error","code":"internal_server_error"}}`) + if !shouldRetryWithFallbackModel(err) { + t.Fatalf("expected auth_unavailable error to trigger fallback retry") + } +}