diff --git a/README.md b/README.md index 7183818..b801b9d 100644 --- a/README.md +++ b/README.md @@ -97,6 +97,7 @@ clawgo config reload - `config set` 使用原子写入。 - 网关运行时若热更新失败,会自动回滚备份,避免损坏配置。 - `--config` 指定的自定义配置路径会被 `config` 命令与通道内 `/config` 指令一致使用。 +- 配置加载使用严格 JSON 解析:未知字段与多余 JSON 内容会直接报错,避免拼写错误被静默忽略。 ## 🌐 通道与消息控制 @@ -106,20 +107,22 @@ clawgo config reload /help /stop /status +/status run [run_id|latest] +/status wait [timeout_seconds] /config get /config set /reload -/autonomy start [idle] -/autonomy stop -/autonomy status -/autolearn start [interval] -/autolearn stop -/autolearn status /pipeline list /pipeline status /pipeline ready ``` +自主与学习控制默认使用自然语言,不再依赖斜杠命令。例如: +- `开始自主模式,每 30 分钟巡检一次` +- `停止自动学习` +- `看看最新 run 的状态` +- `等待 run-1739950000000000000-8 完成后告诉我结果` + 调度语义(按 `session_key`): - 同会话严格 FIFO 串行处理。 - `/stop` 会中断当前回复并继续队列后续消息。 @@ -156,6 +159,38 @@ clawgo channel test --channel telegram --to -m "ping" } ``` +运行控制配置示例(意图阈值 / 自主循环守卫 / 运行态保留): + +```json +"agents": { + "defaults": { + "runtime_control": { + "intent_high_confidence": 0.75, + "intent_confirm_min_confidence": 0.45, + "intent_max_input_chars": 1200, + "confirm_ttl_seconds": 300, + "confirm_max_clarification_turns": 2, + "autonomy_tick_interval_sec": 20, + "autonomy_min_run_interval_sec": 20, + "autonomy_idle_threshold_sec": 20, + "autonomy_max_rounds_without_user": 120, + "autonomy_max_pending_duration_sec": 180, + "autonomy_max_consecutive_stalls": 3, + "autolearn_max_rounds_without_user": 200, + "run_state_ttl_seconds": 1800, + "run_state_max": 500, + "run_control_latest_keywords": ["latest", "last run", "recent run", "最新", "最近", "上一次", "上个"], + "run_control_wait_keywords": ["wait", "等待", "等到", "阻塞"], + "run_control_status_keywords": ["status", "状态", "进度", "running", "运行"], + "run_control_run_mention_keywords": ["run", "任务"], + "run_control_minute_units": ["分钟", "min", "mins", "minute", "minutes", "m"], + "tool_parallel_safe_names": ["read_file", "list_files", "find_files", "grep_files", 
"memory_search", "web_search", "repo_map", "system_info"], + "tool_max_parallel_calls": 2 + } + } +} +``` + ## 🤖 多智能体编排 (Pipeline) 内置标准化编排工具: diff --git a/README_EN.md b/README_EN.md index 563efd3..de67a79 100644 --- a/README_EN.md +++ b/README_EN.md @@ -97,6 +97,7 @@ Notes: - `config set` uses atomic write. - If gateway reload fails while running, config auto-rolls back from backup. - Custom `--config` path is consistently used by CLI config commands and in-channel `/config` commands. +- Config loading uses strict JSON decoding: unknown fields and trailing JSON content now fail fast. ## 🌐 Channels and Message Control @@ -106,20 +107,22 @@ Supported in-channel slash commands: /help /stop /status +/status run [run_id|latest] +/status wait [timeout_seconds] /config get /config set /reload -/autonomy start [idle] -/autonomy stop -/autonomy status -/autolearn start [interval] -/autolearn stop -/autolearn status /pipeline list /pipeline status /pipeline ready ``` +Autonomy and auto-learn control now default to natural language (no slash commands required). Examples: +- `start autonomy mode and check every 30 minutes` +- `stop auto-learn` +- `show latest run status` +- `wait for run-1739950000000000000-8 and report when done` + Scheduling semantics (`session_key` based): - Strict FIFO processing per session. - `/stop` interrupts current response and continues queued messages. 
@@ -156,6 +159,38 @@ Context compaction config example: } ``` +Runtime-control config example (intent thresholds / autonomy guards / run-state retention): + +```json +"agents": { + "defaults": { + "runtime_control": { + "intent_high_confidence": 0.75, + "intent_confirm_min_confidence": 0.45, + "intent_max_input_chars": 1200, + "confirm_ttl_seconds": 300, + "confirm_max_clarification_turns": 2, + "autonomy_tick_interval_sec": 20, + "autonomy_min_run_interval_sec": 20, + "autonomy_idle_threshold_sec": 20, + "autonomy_max_rounds_without_user": 120, + "autonomy_max_pending_duration_sec": 180, + "autonomy_max_consecutive_stalls": 3, + "autolearn_max_rounds_without_user": 200, + "run_state_ttl_seconds": 1800, + "run_state_max": 500, + "run_control_latest_keywords": ["latest", "last run", "recent run", "最新", "最近", "上一次", "上个"], + "run_control_wait_keywords": ["wait", "等待", "等到", "阻塞"], + "run_control_status_keywords": ["status", "状态", "进度", "running", "运行"], + "run_control_run_mention_keywords": ["run", "任务"], + "run_control_minute_units": ["分钟", "min", "mins", "minute", "minutes", "m"], + "tool_parallel_safe_names": ["read_file", "list_files", "find_files", "grep_files", "memory_search", "web_search", "repo_map", "system_info"], + "tool_max_parallel_calls": 2 + } + } +} +``` + ## 🤖 Multi-Agent Orchestration (Pipeline) Built-in orchestration tools: diff --git a/config.example.json b/config.example.json index b6d9ace..f3fe38a 100644 --- a/config.example.json +++ b/config.example.json @@ -14,6 +14,29 @@ "keep_recent_messages": 20, "max_summary_chars": 6000, "max_transcript_chars": 20000 + }, + "runtime_control": { + "intent_high_confidence": 0.75, + "intent_confirm_min_confidence": 0.45, + "intent_max_input_chars": 1200, + "confirm_ttl_seconds": 300, + "confirm_max_clarification_turns": 2, + "autonomy_tick_interval_sec": 20, + "autonomy_min_run_interval_sec": 20, + "autonomy_idle_threshold_sec": 20, + "autonomy_max_rounds_without_user": 120, + 
"autonomy_max_pending_duration_sec": 180, + "autonomy_max_consecutive_stalls": 3, + "autolearn_max_rounds_without_user": 200, + "run_state_ttl_seconds": 1800, + "run_state_max": 500, + "run_control_latest_keywords": ["latest", "last run", "recent run", "最新", "最近", "上一次", "上个"], + "run_control_wait_keywords": ["wait", "等待", "等到", "阻塞"], + "run_control_status_keywords": ["status", "状态", "进度", "running", "运行"], + "run_control_run_mention_keywords": ["run", "任务"], + "run_control_minute_units": ["分钟", "min", "mins", "minute", "minutes", "m"], + "tool_parallel_safe_names": ["read_file", "list_files", "find_files", "grep_files", "memory_search", "web_search", "repo_map", "system_info"], + "tool_max_parallel_calls": 2 } } }, diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go index 38f048f..e47efce 100644 --- a/pkg/agent/loop.go +++ b/pkg/agent/loop.go @@ -41,6 +41,28 @@ const autonomyDefaultIdleInterval = 30 * time.Minute const autonomyMinIdleInterval = 1 * time.Minute const autonomyContinuousRunInterval = 20 * time.Second const autonomyContinuousIdleThreshold = 20 * time.Second +const defaultRunStateTTL = 30 * time.Minute +const defaultRunStateMaxEntries = 500 +const defaultRunWaitTimeout = 60 * time.Second +const minRunWaitTimeout = 5 * time.Second +const maxRunWaitTimeout = 15 * time.Minute +const toolLoopRepeatSignatureThreshold = 2 +const toolLoopAllErrorRoundsThreshold = 2 +const toolLoopMaxCallsPerIteration = 6 +const toolLoopSingleCallTimeout = 20 * time.Second +const toolLoopMaxActDuration = 45 * time.Second +const toolLoopReflectTimeout = 6 * time.Second +const toolLoopMinCallsPerIteration = 2 +const toolLoopMinSingleCallTimeout = 8 * time.Second +const toolLoopMinActDuration = 18 * time.Second +const toolLoopMaxParallelCalls = 2 +const finalizeDraftMinCharsForPolish = 90 +const finalizeQualityThreshold = 0.72 +const finalizeHeuristicHighThreshold = 0.82 +const finalizeHeuristicLowThreshold = 0.48 +const reflectionCooldownRounds = 2 +const toolSummaryMaxRecords = 
4 +const maxSelfRepairPasses = 2 type sessionWorker struct { queue chan bus.InboundMessage @@ -84,7 +106,20 @@ type controlPolicy struct { autoLearnMaxRoundsWithoutUser int } +type runControlLexicon struct { + latestKeywords []string + waitKeywords []string + statusKeywords []string + runMentionKeywords []string + minuteUnits map[string]struct{} +} + type runtimeControlStats struct { + runAccepted int64 + runCompleted int64 + runFailed int64 + runCanceled int64 + runControlHandled int64 intentAutonomyMatched int64 intentAutonomyNeedsConfirm int64 intentAutonomyRejected int64 @@ -101,33 +136,80 @@ type runtimeControlStats struct { autoLearnStoppedByGuard int64 } +type runStatus string + +const ( + runStatusAccepted runStatus = "accepted" + runStatusRunning runStatus = "running" + runStatusOK runStatus = "ok" + runStatusError runStatus = "error" + runStatusCanceled runStatus = "canceled" +) + +type agentRunLifecycle struct { + runID string + acceptedAt time.Time + sessionKey string + channel string + senderID string + chatID string + synthetic bool + controlEligible bool +} + +type runState struct { + runID string + sessionKey string + channel string + chatID string + senderID string + synthetic bool + controlEligible bool + status runStatus + acceptedAt time.Time + startedAt time.Time + endedAt time.Time + errMessage string + responseLen int + controlHandled bool + done chan struct{} +} + type AgentLoop struct { - bus *bus.MessageBus - provider providers.LLMProvider - providersByProxy map[string]providers.LLMProvider - modelsByProxy map[string][]string - proxy string - proxyFallbacks []string - workspace string - model string - maxIterations int - sessions *session.SessionManager - contextBuilder *ContextBuilder - tools *tools.ToolRegistry - orchestrator *tools.Orchestrator - running atomic.Bool - compactionCfg config.ContextCompactionConfig - llmCallTimeout time.Duration - workersMu sync.Mutex - workers map[string]*sessionWorker - autoLearnMu sync.Mutex - 
autoLearners map[string]*autoLearner - autonomyMu sync.Mutex - autonomyBySess map[string]*autonomySession - controlConfirmMu sync.Mutex - controlConfirm map[string]pendingControlConfirmation - controlPolicy controlPolicy - controlStats runtimeControlStats + bus *bus.MessageBus + provider providers.LLMProvider + providersByProxy map[string]providers.LLMProvider + modelsByProxy map[string][]string + proxy string + proxyFallbacks []string + workspace string + model string + maxIterations int + sessions *session.SessionManager + contextBuilder *ContextBuilder + tools *tools.ToolRegistry + orchestrator *tools.Orchestrator + running atomic.Bool + compactionCfg config.ContextCompactionConfig + llmCallTimeout time.Duration + workersMu sync.Mutex + workers map[string]*sessionWorker + autoLearnMu sync.Mutex + autoLearners map[string]*autoLearner + autonomyMu sync.Mutex + autonomyBySess map[string]*autonomySession + controlConfirmMu sync.Mutex + controlConfirm map[string]pendingControlConfirmation + controlPolicy controlPolicy + runControlLex runControlLexicon + parallelSafeTools map[string]struct{} + maxParallelCalls int + controlStats runtimeControlStats + runSeq atomic.Int64 + runStateMu sync.Mutex + runStates map[string]*runState + runStateTTL time.Duration + runStateMax int } type taskExecutionDirectives struct { @@ -135,6 +217,13 @@ type taskExecutionDirectives struct { stageReport bool } +type runControlIntent struct { + runID string + latest bool + wait bool + timeout time.Duration +} + type autoLearnIntent struct { action string interval *time.Duration @@ -171,6 +260,22 @@ type taskExecutionDirectivesLLMResponse struct { Confidence float64 `json:"confidence"` } +var runIDPattern = regexp.MustCompile(`(?i)\b(run-\d+-\d+)\b`) +var runWaitTimeoutPattern = regexp.MustCompile(`(?i)(\d+)\s*((?:seconds|second|secs|sec|minutes|minute|mins|min|s|m)\b|分钟|秒)`) // ASCII units require a trailing word boundary so "3 steps" is not parsed as 3 seconds; the two capture groups (count, unit) are unchanged. +var defaultRunControlLatestKeywords = []string{"latest", "last run", "recent run", "最新", "最近", "上一次", "上个"} +var 
defaultRunControlWaitKeywords = []string{"wait", "等待", "等到", "阻塞"} +var defaultRunControlStatusKeywords = []string{"status", "状态", "进度", "running", "运行"} +var defaultRunControlRunMentionKeywords = []string{"run", "任务"} +var defaultParallelSafeToolNames = []string{"read_file", "list_files", "find_files", "grep_files", "memory_search", "web_search", "repo_map", "system_info"} +var defaultRunWaitMinuteUnits = map[string]struct{}{ + "分钟": {}, + "min": {}, + "mins": {}, + "minute": {}, + "minutes": {}, + "m": {}, +} + type stageReporter struct { onUpdate func(content string) localize func(content string) string @@ -254,7 +359,139 @@ func envDuration(key string, fallback time.Duration) time.Duration { return d } -func loadControlPolicyFromEnv(base controlPolicy) controlPolicy { +func loadControlPolicyFromConfig(base controlPolicy, rc config.RuntimeControlConfig) controlPolicy { + p := base + if rc.IntentHighConfidence > 0 { + p.intentHighConfidence = rc.IntentHighConfidence + } + if rc.IntentConfirmMinConfidence >= 0 { + p.intentConfirmMinConfidence = rc.IntentConfirmMinConfidence + } + if rc.IntentMaxInputChars > 0 { + p.intentMaxInputChars = rc.IntentMaxInputChars + } + if rc.ConfirmTTLSeconds > 0 { + p.confirmTTL = time.Duration(rc.ConfirmTTLSeconds) * time.Second + } + if rc.ConfirmMaxClarificationTurns >= 0 { + p.confirmMaxClarificationTurns = rc.ConfirmMaxClarificationTurns + } + if rc.AutonomyTickIntervalSec > 0 { + p.autonomyTickInterval = time.Duration(rc.AutonomyTickIntervalSec) * time.Second + } + if rc.AutonomyMinRunIntervalSec > 0 { + p.autonomyMinRunInterval = time.Duration(rc.AutonomyMinRunIntervalSec) * time.Second + } + if rc.AutonomyIdleThresholdSec > 0 { + p.autonomyIdleThreshold = time.Duration(rc.AutonomyIdleThresholdSec) * time.Second + } + if rc.AutonomyMaxRoundsWithoutUser > 0 { + p.autonomyMaxRoundsWithoutUser = rc.AutonomyMaxRoundsWithoutUser + } + if rc.AutonomyMaxPendingDurationSec > 0 { + p.autonomyMaxPendingDuration = 
time.Duration(rc.AutonomyMaxPendingDurationSec) * time.Second + } + if rc.AutonomyMaxConsecutiveStalls > 0 { + p.autonomyMaxConsecutiveStalls = rc.AutonomyMaxConsecutiveStalls + } + if rc.AutoLearnMaxRoundsWithoutUser > 0 { + p.autoLearnMaxRoundsWithoutUser = rc.AutoLearnMaxRoundsWithoutUser + } + return p +} + +func loadRunStatePolicyFromConfig(rc config.RuntimeControlConfig) (time.Duration, int) { + ttl := defaultRunStateTTL + if rc.RunStateTTLSeconds > 0 { + ttl = time.Duration(rc.RunStateTTLSeconds) * time.Second + } + maxEntries := defaultRunStateMaxEntries + if rc.RunStateMax > 0 { + maxEntries = rc.RunStateMax + } + return ttl, maxEntries +} + +func defaultRunControlLexicon() runControlLexicon { + latest := append([]string(nil), defaultRunControlLatestKeywords...) + wait := append([]string(nil), defaultRunControlWaitKeywords...) + status := append([]string(nil), defaultRunControlStatusKeywords...) + mention := append([]string(nil), defaultRunControlRunMentionKeywords...) + minutes := make(map[string]struct{}, len(defaultRunWaitMinuteUnits)) + for unit := range defaultRunWaitMinuteUnits { + minutes[unit] = struct{}{} + } + return runControlLexicon{ + latestKeywords: latest, + waitKeywords: wait, + statusKeywords: status, + runMentionKeywords: mention, + minuteUnits: minutes, + } +} + +func loadRunControlLexiconFromConfig(rc config.RuntimeControlConfig) runControlLexicon { + base := defaultRunControlLexicon() + base.latestKeywords = normalizeKeywordList(rc.RunControlLatestKeywords, base.latestKeywords) + base.waitKeywords = normalizeKeywordList(rc.RunControlWaitKeywords, base.waitKeywords) + base.statusKeywords = normalizeKeywordList(rc.RunControlStatusKeywords, base.statusKeywords) + base.runMentionKeywords = normalizeKeywordList(rc.RunControlRunMentionKeywords, base.runMentionKeywords) + + minuteUnits := normalizeKeywordList(rc.RunControlMinuteUnits, nil) + if len(minuteUnits) == 0 { + return base + } + base.minuteUnits = make(map[string]struct{}, 
len(minuteUnits)) + for _, unit := range minuteUnits { + base.minuteUnits[unit] = struct{}{} + } + return base +} + +func normalizeKeywordList(values []string, fallback []string) []string { + if len(values) == 0 { + return append([]string(nil), fallback...) + } + out := make([]string, 0, len(values)) + seen := make(map[string]struct{}, len(values)) + for _, value := range values { + normalized := strings.ToLower(strings.TrimSpace(value)) + if normalized == "" { + continue + } + if _, ok := seen[normalized]; ok { + continue + } + seen[normalized] = struct{}{} + out = append(out, normalized) + } + if len(out) == 0 { + return append([]string(nil), fallback...) + } + return out +} + +func loadToolParallelPolicyFromConfig(rc config.RuntimeControlConfig) (map[string]struct{}, int) { + names := normalizeKeywordList(rc.ToolParallelSafeNames, defaultParallelSafeToolNames) + allowed := make(map[string]struct{}, len(names)) + for _, name := range names { + allowed[name] = struct{}{} + } + maxParallel := rc.ToolMaxParallelCalls + if maxParallel <= 0 { + maxParallel = toolLoopMaxParallelCalls + } + if maxParallel < 1 { + maxParallel = 1 + } + if maxParallel > 8 { + maxParallel = 8 + } + return allowed, maxParallel +} + +// applyLegacyControlPolicyEnvOverrides keeps compatibility with older env names. 
+func applyLegacyControlPolicyEnvOverrides(base controlPolicy) controlPolicy { p := base p.intentHighConfidence = envFloat64("CLAWGO_INTENT_HIGH_CONFIDENCE", p.intentHighConfidence) p.intentConfirmMinConfidence = envFloat64("CLAWGO_INTENT_CONFIRM_MIN_CONFIDENCE", p.intentConfirmMinConfidence) @@ -438,29 +675,48 @@ func NewAgentLoop(cfg *config.Config, msgBus *bus.MessageBus, provider providers provider = p } defaultModel := defaultModelFromModels(modelsByProxy[primaryProxy], provider) - policy := loadControlPolicyFromEnv(defaultControlPolicy()) + policy := loadControlPolicyFromConfig(defaultControlPolicy(), cfg.Agents.Defaults.RuntimeControl) + policy = applyLegacyControlPolicyEnvOverrides(policy) + runControlLex := loadRunControlLexiconFromConfig(cfg.Agents.Defaults.RuntimeControl) + parallelSafeTools, maxParallelCalls := loadToolParallelPolicyFromConfig(cfg.Agents.Defaults.RuntimeControl) + runStateTTL, runStateMax := loadRunStatePolicyFromConfig(cfg.Agents.Defaults.RuntimeControl) + // Keep compatibility with older env names. 
+ runStateTTL = envDuration("CLAWGO_RUN_STATE_TTL", runStateTTL) + if runStateTTL < 1*time.Minute { + runStateTTL = defaultRunStateTTL + } + runStateMax = envInt("CLAWGO_RUN_STATE_MAX", runStateMax) + if runStateMax <= 0 { + runStateMax = defaultRunStateMaxEntries + } loop := &AgentLoop{ - bus: msgBus, - provider: provider, - providersByProxy: providersByProxy, - modelsByProxy: modelsByProxy, - proxy: primaryProxy, - proxyFallbacks: parseStringList(cfg.Agents.Defaults.ProxyFallbacks), - workspace: workspace, - model: defaultModel, - maxIterations: cfg.Agents.Defaults.MaxToolIterations, - sessions: sessionsManager, - contextBuilder: NewContextBuilder(workspace, cfg.Memory, func() []string { return toolsRegistry.GetSummaries() }), - tools: toolsRegistry, - orchestrator: orchestrator, - compactionCfg: cfg.Agents.Defaults.ContextCompaction, - llmCallTimeout: time.Duration(cfg.Providers.Proxy.TimeoutSec) * time.Second, - workers: make(map[string]*sessionWorker), - autoLearners: make(map[string]*autoLearner), - autonomyBySess: make(map[string]*autonomySession), - controlConfirm: make(map[string]pendingControlConfirmation), - controlPolicy: policy, + bus: msgBus, + provider: provider, + providersByProxy: providersByProxy, + modelsByProxy: modelsByProxy, + proxy: primaryProxy, + proxyFallbacks: parseStringList(cfg.Agents.Defaults.ProxyFallbacks), + workspace: workspace, + model: defaultModel, + maxIterations: cfg.Agents.Defaults.MaxToolIterations, + sessions: sessionsManager, + contextBuilder: NewContextBuilder(workspace, cfg.Memory, func() []string { return toolsRegistry.GetSummaries() }), + tools: toolsRegistry, + orchestrator: orchestrator, + compactionCfg: cfg.Agents.Defaults.ContextCompaction, + llmCallTimeout: time.Duration(cfg.Providers.Proxy.TimeoutSec) * time.Second, + workers: make(map[string]*sessionWorker), + autoLearners: make(map[string]*autoLearner), + autonomyBySess: make(map[string]*autonomySession), + controlConfirm: 
make(map[string]pendingControlConfirmation), + controlPolicy: policy, + runControlLex: runControlLex, + parallelSafeTools: parallelSafeTools, + maxParallelCalls: maxParallelCalls, + runStates: make(map[string]*runState), + runStateTTL: runStateTTL, + runStateMax: runStateMax, } logger.InfoCF("agent", "Control policy initialized", map[string]interface{}{ "intent_high_confidence": policy.intentHighConfidence, @@ -475,6 +731,15 @@ func NewAgentLoop(cfg *config.Config, msgBus *bus.MessageBus, provider providers "autonomy_max_pending_duration": policy.autonomyMaxPendingDuration.String(), "autonomy_max_consecutive_stalls": policy.autonomyMaxConsecutiveStalls, "autolearn_max_rounds_without_user": policy.autoLearnMaxRoundsWithoutUser, + "run_control_latest_keywords": len(runControlLex.latestKeywords), + "run_control_wait_keywords": len(runControlLex.waitKeywords), + "run_control_status_keywords": len(runControlLex.statusKeywords), + "run_control_run_keywords": len(runControlLex.runMentionKeywords), + "run_control_minute_units": len(runControlLex.minuteUnits), + "parallel_safe_tool_count": len(parallelSafeTools), + "tool_max_parallel_calls": maxParallelCalls, + "run_state_ttl": runStateTTL.String(), + "run_state_max": runStateMax, }) // Inject recursive run logic so subagent has full tool-calling capability. 
@@ -1180,6 +1445,392 @@ func shouldHandleControlIntents(msg bus.InboundMessage) bool { return !isSyntheticMessage(msg) } +func (al *AgentLoop) beginAgentRun(msg bus.InboundMessage) agentRunLifecycle { + seq := int64(1) + if al != nil { + seq = al.runSeq.Add(1) + } + acceptedAt := time.Now() + run := agentRunLifecycle{ + runID: fmt.Sprintf("run-%d-%d", acceptedAt.UnixNano(), seq), + acceptedAt: acceptedAt, + sessionKey: msg.SessionKey, + channel: msg.Channel, + senderID: msg.SenderID, + chatID: msg.ChatID, + synthetic: isSyntheticMessage(msg), + controlEligible: shouldHandleControlIntents(msg), + } + if al != nil { + al.controlMetricAdd(&al.controlStats.runAccepted, 1) + al.recordRunAccepted(run) + } + logger.InfoCF("agent", "Run lifecycle accepted", map[string]interface{}{ + "run_id": run.runID, + "session_key": run.sessionKey, + logger.FieldChannel: run.channel, + logger.FieldChatID: run.chatID, + "sender_id": run.senderID, + "synthetic": run.synthetic, + "control_eligible": run.controlEligible, + "accepted_at": run.acceptedAt.Format(time.RFC3339Nano), + }) + return run +} + +func (al *AgentLoop) finishAgentRun(run agentRunLifecycle, err error, response string, controlHandled bool) { + duration := time.Since(run.acceptedAt) + if al != nil { + if controlHandled { + al.controlMetricAdd(&al.controlStats.runControlHandled, 1) + } + switch { + case err == nil: + al.controlMetricAdd(&al.controlStats.runCompleted, 1) + case errors.Is(err, context.Canceled): + al.controlMetricAdd(&al.controlStats.runCanceled, 1) + default: + al.controlMetricAdd(&al.controlStats.runFailed, 1) + } + al.recordRunFinished(run, err, response, controlHandled) + } + + fields := map[string]interface{}{ + "run_id": run.runID, + "session_key": run.sessionKey, + logger.FieldChannel: run.channel, + logger.FieldChatID: run.chatID, + "duration_ms": duration.Milliseconds(), + "response_len": len(strings.TrimSpace(response)), + "control_handled": controlHandled, + } + if err != nil { + 
fields[logger.FieldError] = err.Error() + logger.WarnCF("agent", "Run lifecycle ended with error", fields) + return + } + logger.InfoCF("agent", "Run lifecycle ended", fields) +} + +func (al *AgentLoop) recordRunAccepted(run agentRunLifecycle) { + if al == nil || strings.TrimSpace(run.runID) == "" { + return + } + al.runStateMu.Lock() + defer al.runStateMu.Unlock() + rs := &runState{ + runID: run.runID, + sessionKey: run.sessionKey, + channel: run.channel, + chatID: run.chatID, + senderID: run.senderID, + synthetic: run.synthetic, + controlEligible: run.controlEligible, + status: runStatusRunning, + acceptedAt: run.acceptedAt, + startedAt: run.acceptedAt, + done: make(chan struct{}), + } + al.runStates[run.runID] = rs + al.pruneRunStatesLocked(time.Now()) +} + +func (al *AgentLoop) recordRunFinished(run agentRunLifecycle, err error, response string, controlHandled bool) { + if al == nil || strings.TrimSpace(run.runID) == "" { + return + } + al.runStateMu.Lock() + defer al.runStateMu.Unlock() + rs, ok := al.runStates[run.runID] + if !ok || rs == nil { + rs = &runState{ + runID: run.runID, + sessionKey: run.sessionKey, + channel: run.channel, + chatID: run.chatID, + senderID: run.senderID, + done: make(chan struct{}), + } + al.runStates[run.runID] = rs + } + rs.endedAt = time.Now() + rs.responseLen = len(strings.TrimSpace(response)) + rs.controlHandled = controlHandled + if err == nil { + rs.status = runStatusOK + } else if errors.Is(err, context.Canceled) { + rs.status = runStatusCanceled + rs.errMessage = err.Error() + } else { + rs.status = runStatusError + rs.errMessage = err.Error() + } + select { + case <-rs.done: + default: + close(rs.done) + } + al.pruneRunStatesLocked(rs.endedAt) +} + +func (al *AgentLoop) pruneRunStatesLocked(now time.Time) { + if al == nil || al.runStates == nil { + return + } + ttl := al.runStateTTL + if ttl <= 0 { + ttl = defaultRunStateTTL + } + maxEntries := al.runStateMax + if maxEntries <= 0 { + maxEntries = defaultRunStateMaxEntries 
+ } + for id, rs := range al.runStates { + if rs == nil { + delete(al.runStates, id) + continue + } + if rs.endedAt.IsZero() { + continue + } + if now.Sub(rs.endedAt) > ttl { + delete(al.runStates, id) + } + } + if len(al.runStates) <= maxEntries { + return + } + type pair struct { + id string + ts time.Time + } + items := make([]pair, 0, len(al.runStates)) + for id, rs := range al.runStates { + ts := rs.startedAt + if !rs.endedAt.IsZero() { + ts = rs.endedAt + } + items = append(items, pair{id: id, ts: ts}) + } + sort.Slice(items, func(i, j int) bool { return items[i].ts.Before(items[j].ts) }) + removeCount := len(items) - maxEntries + for i := 0; i < removeCount; i++ { + delete(al.runStates, items[i].id) + } +} + +func (al *AgentLoop) getRunState(runID string) (runState, bool) { + if al == nil || strings.TrimSpace(runID) == "" { + return runState{}, false + } + al.runStateMu.Lock() + defer al.runStateMu.Unlock() + rs, ok := al.runStates[runID] + if !ok || rs == nil { + return runState{}, false + } + return *rs, true +} + +func (al *AgentLoop) latestRunState(sessionKey string) (runState, bool) { + if al == nil { + return runState{}, false + } + key := strings.TrimSpace(sessionKey) + al.runStateMu.Lock() + defer al.runStateMu.Unlock() + + var latest *runState + var latestAt time.Time + for _, rs := range al.runStates { + if rs == nil { + continue + } + if key != "" && rs.sessionKey != key { + continue + } + candidateAt := rs.acceptedAt + if !rs.startedAt.IsZero() { + candidateAt = rs.startedAt + } + if !rs.endedAt.IsZero() { + candidateAt = rs.endedAt + } + if latest == nil || candidateAt.After(latestAt) { + cp := *rs + latest = &cp + latestAt = candidateAt + } + } + if latest == nil { + return runState{}, false + } + return *latest, true +} + +func (al *AgentLoop) waitForRun(ctx context.Context, runID string) (runState, bool) { + if al == nil || strings.TrimSpace(runID) == "" { + return runState{}, false + } + al.runStateMu.Lock() + rs, ok := al.runStates[runID] + 
al.runStateMu.Unlock() + if !ok || rs == nil { + return runState{}, false + } + select { + case <-ctx.Done(): + return runState{}, false + case <-rs.done: + return al.getRunState(runID) + } +} + +func detectRunControlIntent(content string) (runControlIntent, bool) { + return detectRunControlIntentWithLexicon(content, defaultRunControlLexicon()) +} + +func detectRunControlIntentWithLexicon(content string, lex runControlLexicon) (runControlIntent, bool) { + text := strings.TrimSpace(content) + if text == "" { + return runControlIntent{}, false + } + if strings.HasPrefix(text, "/") { + return runControlIntent{}, false + } + + lower := strings.ToLower(text) + intent := runControlIntent{ + timeout: defaultRunWaitTimeout, + } + if m := runIDPattern.FindStringSubmatch(text); len(m) > 1 { + intent.runID = strings.ToLower(strings.TrimSpace(m[1])) + } + intent.latest = containsAnySubstring(lower, lex.latestKeywords...) + intent.wait = containsAnySubstring(lower, lex.waitKeywords...) + isStatusQuery := containsAnySubstring(lower, lex.statusKeywords...) + isRunMentioned := containsAnySubstring(lower, lex.runMentionKeywords...) 
+ if !intent.wait && !isStatusQuery { + if intent.runID == "" || !isRunMentioned { + return runControlIntent{}, false + } + } + if intent.runID == "" && !intent.latest { + return runControlIntent{}, false + } + if intent.wait { + intent.timeout = parseRunWaitTimeoutWithLexicon(text, lex) + } + return intent, true +} + +func parseRunWaitTimeout(content string) time.Duration { + return parseRunWaitTimeoutWithLexicon(content, defaultRunControlLexicon()) +} + +func parseRunWaitTimeoutWithLexicon(content string, lex runControlLexicon) time.Duration { + timeout := defaultRunWaitTimeout + matches := runWaitTimeoutPattern.FindStringSubmatch(runIDPattern.ReplaceAllString(content, " ")) // blank out run IDs first so their trailing sequence digits are never read as the timeout count + if len(matches) < 3 { + return timeout + } + n, err := strconv.Atoi(matches[1]) + if err != nil || n <= 0 { + return timeout + } + unit := strings.ToLower(strings.TrimSpace(matches[2])) + if _, isMinute := lex.minuteUnits[unit]; isMinute { + timeout = time.Duration(n) * time.Minute + } else { + timeout = time.Duration(n) * time.Second + } + if timeout < minRunWaitTimeout { + return minRunWaitTimeout + } + if timeout > maxRunWaitTimeout { + return maxRunWaitTimeout + } + return timeout +} + +func containsAnySubstring(text string, values ...string) bool { + for _, value := range values { + if value != "" && strings.Contains(text, value) { + return true + } + } + return false +} + +func formatRunStateReport(rs runState) string { + lines := []string{ + fmt.Sprintf("Run ID: %s", rs.runID), + fmt.Sprintf("Status: %s", rs.status), + fmt.Sprintf("Session: %s", rs.sessionKey), + fmt.Sprintf("Accepted At: %s", rs.acceptedAt.Format(time.RFC3339)), + } + if !rs.startedAt.IsZero() { + lines = append(lines, fmt.Sprintf("Started At: %s", rs.startedAt.Format(time.RFC3339))) + } + if !rs.endedAt.IsZero() { + lines = append(lines, fmt.Sprintf("Ended At: %s", rs.endedAt.Format(time.RFC3339))) + lines = append(lines, fmt.Sprintf("Duration: %s", rs.endedAt.Sub(rs.startedAt).Truncate(time.Millisecond))) + } else if !rs.startedAt.IsZero() { + lines 
= append(lines, fmt.Sprintf("Elapsed: %s", time.Since(rs.startedAt).Truncate(time.Second))) + } + lines = append(lines, fmt.Sprintf("Control Handled: %v", rs.controlHandled)) + lines = append(lines, fmt.Sprintf("Response Length: %d", rs.responseLen)) + if strings.TrimSpace(rs.errMessage) != "" { + lines = append(lines, fmt.Sprintf("Error: %s", rs.errMessage)) + } + return strings.Join(lines, "\n") +} + +func (al *AgentLoop) executeRunControlIntent(ctx context.Context, sessionKey string, intent runControlIntent) string { + var ( + rs runState + found bool + ) + if intent.latest { + rs, found = al.latestRunState(sessionKey) + } else { + rs, found = al.getRunState(intent.runID) + } + if !found { + return al.naturalizeUserFacingText(ctx, "No matching run state found. Try specifying a run ID or asking for the latest run status.") + } + + if intent.wait && (rs.status == runStatusAccepted || rs.status == runStatusRunning) { + waitCtx, cancel := context.WithTimeout(ctx, intent.timeout) + defer cancel() + if waited, ok := al.waitForRun(waitCtx, rs.runID); ok { + rs = waited + } else { + if latest, ok := al.getRunState(rs.runID); ok { + rs = latest + } + fallback := fmt.Sprintf("Run %s is still %s after waiting %s.\n%s", rs.runID, rs.status, intent.timeout.Truncate(time.Second), formatRunStateReport(rs)) + return al.naturalizeUserFacingText(ctx, fallback) + } + } + return al.naturalizeUserFacingText(ctx, formatRunStateReport(rs)) +} + +func (al *AgentLoop) handleNaturalRunControl(ctx context.Context, msg bus.InboundMessage) (bool, string) { + intent, ok := detectRunControlIntentWithLexicon(msg.Content, al.effectiveRunControlLexicon()) + if !ok { + return false, "" + } + return true, al.executeRunControlIntent(ctx, msg.SessionKey, intent) +} + +func (al *AgentLoop) effectiveRunControlLexicon() runControlLexicon { + if al == nil || len(al.runControlLex.latestKeywords) == 0 || len(al.runControlLex.minuteUnits) == 0 { + return defaultRunControlLexicon() + } + return 
al.runControlLex +} + func (al *AgentLoop) controlMetricAdd(counter *int64, delta int64) { if al == nil || counter == nil { return @@ -1203,8 +1854,16 @@ func (al *AgentLoop) logControlStatsSnapshot() { al.controlConfirmMu.Lock() pendingConfirm := len(al.controlConfirm) al.controlConfirmMu.Unlock() + al.runStateMu.Lock() + runStatesTotal := len(al.runStates) + al.runStateMu.Unlock() stats := map[string]interface{}{ + "run_accepted": atomic.LoadInt64(&al.controlStats.runAccepted), + "run_completed": atomic.LoadInt64(&al.controlStats.runCompleted), + "run_failed": atomic.LoadInt64(&al.controlStats.runFailed), + "run_canceled": atomic.LoadInt64(&al.controlStats.runCanceled), + "run_control_handled": atomic.LoadInt64(&al.controlStats.runControlHandled), "intent_autonomy_matched": atomic.LoadInt64(&al.controlStats.intentAutonomyMatched), "intent_autonomy_needs_confirm": atomic.LoadInt64(&al.controlStats.intentAutonomyNeedsConfirm), "intent_autonomy_rejected": atomic.LoadInt64(&al.controlStats.intentAutonomyRejected), @@ -1222,6 +1881,7 @@ func (al *AgentLoop) logControlStatsSnapshot() { "autonomy_active_sessions": autonomyActive, "autolearn_active_sessions": autoLearnActive, "pending_control_confirm_sessions": pendingConfirm, + "run_states_total": runStatesTotal, } logger.InfoCF("agent", "Control runtime snapshot", stats) } @@ -1271,6 +1931,44 @@ func (al *AgentLoop) executeAutoLearnIntent(ctx context.Context, msg bus.Inbound } } +func (al *AgentLoop) handleControlPlane(ctx context.Context, msg bus.InboundMessage) (handled bool, response string, err error) { + if !shouldHandleControlIntents(msg) { + return false, "", nil + } + + // Deterministic commands first. 
+ if handled, result, cmdErr := al.handleSlashCommand(ctx, msg); handled { + return true, result, cmdErr + } + + al.noteAutonomyUserActivity(msg) + + if handled, result := al.handlePendingControlConfirmation(ctx, msg); handled { + return true, result, nil + } + if handled, result := al.handleNaturalRunControl(ctx, msg); handled { + return true, result, nil + } + + if intent, outcome := al.detectAutonomyIntent(ctx, msg.Content); outcome.matched { + al.clearPendingControlConfirmation(msg.SessionKey) + return true, al.executeAutonomyIntent(ctx, msg, intent), nil + } else if outcome.needsConfirm { + al.storePendingAutonomyConfirmation(msg.SessionKey, msg.Content, intent, outcome.confidence) + return true, al.naturalizeUserFacingText(ctx, al.formatAutonomyConfirmationPrompt(intent)), nil + } + + if intent, outcome := al.detectAutoLearnIntent(ctx, msg.Content); outcome.matched { + al.clearPendingControlConfirmation(msg.SessionKey) + return true, al.executeAutoLearnIntent(ctx, msg, intent), nil + } else if outcome.needsConfirm { + al.storePendingAutoLearnConfirmation(msg.SessionKey, msg.Content, intent, outcome.confidence) + return true, al.naturalizeUserFacingText(ctx, al.formatAutoLearnConfirmationPrompt(intent)), nil + } + + return false, "", nil +} + func (al *AgentLoop) handlePendingControlConfirmation(ctx context.Context, msg bus.InboundMessage) (bool, string) { pending, ok := al.getPendingControlConfirmation(msg.SessionKey) if !ok { @@ -1652,8 +2350,13 @@ func (al *AgentLoop) ProcessDirect(ctx context.Context, content, sessionKey stri return al.processMessage(ctx, msg) } -func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage) (string, error) { +func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage) (response string, err error) { ctx = withUserLanguageHint(ctx, msg.SessionKey, msg.Content) + run := al.beginAgentRun(msg) + controlHandled := false + defer func() { + al.finishAgentRun(run, err, response, 
controlHandled) + }() // Add message preview to log preview := truncate(msg.Content, 80) @@ -1669,50 +2372,18 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage) defer al.finishAutonomyRound(msg.SessionKey) } - controlEligible := shouldHandleControlIntents(msg) - // Route system messages to processSystemMessage if msg.Channel == "system" { return al.processSystemMessage(ctx, msg) } - // Built-in slash commands (deterministic, no LLM roundtrip) - if controlEligible { - if handled, result, err := al.handleSlashCommand(ctx, msg); handled { - return result, err - } - } - - if controlEligible { - al.noteAutonomyUserActivity(msg) - } - - if controlEligible { - if handled, response := al.handlePendingControlConfirmation(ctx, msg); handled { - return response, nil - } - } - - if controlEligible { - if intent, outcome := al.detectAutonomyIntent(ctx, msg.Content); outcome.matched { - al.clearPendingControlConfirmation(msg.SessionKey) - return al.executeAutonomyIntent(ctx, msg, intent), nil - } else if outcome.needsConfirm { - al.storePendingAutonomyConfirmation(msg.SessionKey, msg.Content, intent, outcome.confidence) - return al.naturalizeUserFacingText(ctx, al.formatAutonomyConfirmationPrompt(intent)), nil - } - - if intent, outcome := al.detectAutoLearnIntent(ctx, msg.Content); outcome.matched { - al.clearPendingControlConfirmation(msg.SessionKey) - return al.executeAutoLearnIntent(ctx, msg, intent), nil - } else if outcome.needsConfirm { - al.storePendingAutoLearnConfirmation(msg.SessionKey, msg.Content, intent, outcome.confidence) - return al.naturalizeUserFacingText(ctx, al.formatAutoLearnConfirmationPrompt(intent)), nil - } + if handled, controlResponse, controlErr := al.handleControlPlane(ctx, msg); handled { + controlHandled = true + return controlResponse, controlErr } directives := parseTaskExecutionDirectives(msg.Content) - if controlEligible { + if run.controlEligible { if inferred, ok := al.inferTaskExecutionDirectives(ctx, 
msg.Content); ok { // Explicit /run/@run command always has higher priority than inferred directives. if !isExplicitRunCommand(msg.Content) { @@ -1727,7 +2398,7 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage) if strings.TrimSpace(userPrompt) == "" { userPrompt = msg.Content } - if al.isAutonomyEnabled(msg.SessionKey) && controlEligible { + if al.isAutonomyEnabled(msg.SessionKey) && run.controlEligible { userPrompt = buildAutonomyTaskPrompt(userPrompt) } @@ -1784,6 +2455,10 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage) } return "", err } + finalContent, repairPasses := al.runSelfRepairIfNeeded(ctx, msg.SessionKey, userPrompt, messages, finalContent, progress) + if repairPasses > 0 { + iteration += repairPasses + } if finalContent == "" { finalContent = "Done." @@ -1930,72 +2605,22 @@ func (al *AgentLoop) runLLMToolLoop( ) (string, int, error) { messages = sanitizeMessagesForToolCalling(messages) - iteration := 0 - var finalContent string - var lastToolResult string + state := toolLoopState{} - for iteration < al.maxIterations { - iteration++ + for state.iteration < al.maxIterations { + state.iteration++ + iteration := state.iteration if progress != nil { progress.Publish(3, 5, "execution", fmt.Sprintf("Running iteration %d.", iteration)) } - if !systemMode { - logger.DebugCF("agent", "LLM iteration", - map[string]interface{}{ - "iteration": iteration, - "max": al.maxIterations, - }) - } - providerToolDefs, err := buildProviderToolDefs(al.tools.GetDefinitions()) if err != nil { return "", iteration, fmt.Errorf("invalid tool definition: %w", err) } - messages = sanitizeMessagesForToolCalling(messages) - - systemPromptLen := 0 - if len(messages) > 0 { - systemPromptLen = len(messages[0].Content) - } - logger.DebugCF("agent", "LLM request", - map[string]interface{}{ - "iteration": iteration, - "model": al.model, - "messages_count": len(messages), - "tools_count": len(providerToolDefs), - 
"max_tokens": 8192, - "temperature": 0.7, - "system_prompt_len": systemPromptLen, - }) - logger.DebugCF("agent", "Full LLM request", - map[string]interface{}{ - "iteration": iteration, - "messages_json": formatMessagesForLog(messages), - "tools_json": formatToolsForLog(providerToolDefs), - }) - - llmStart := time.Now() - llmCtx, cancelLLM := context.WithTimeout(ctx, al.llmCallTimeout) - response, err := al.callLLMWithModelFallback(llmCtx, messages, providerToolDefs, map[string]interface{}{ - "max_tokens": 8192, - "temperature": 0.7, - }) - cancelLLM() - llmElapsed := time.Since(llmStart) - + response, llmElapsed, err := al.planToolCalls(ctx, messages, providerToolDefs, iteration, systemMode) if err != nil { - errLog := "LLM call failed" - if systemMode { - errLog = "LLM call failed in system message" - } - logger.ErrorCF("agent", errLog, - map[string]interface{}{ - "iteration": iteration, - logger.FieldError: err.Error(), - "elapsed": llmElapsed.String(), - }) return "", iteration, fmt.Errorf("LLM call failed: %w", err) } @@ -2011,17 +2636,40 @@ func (al *AgentLoop) runLLMToolLoop( }) if len(response.ToolCalls) == 0 { - finalContent = response.Content + state.finalContent = response.Content + state.consecutiveAllToolErrorRounds = 0 + state.repeatedToolCallRounds = 0 + state.lastToolCallSignature = "" if !systemMode { logger.InfoCF("agent", "LLM response without tool calls (direct answer)", map[string]interface{}{ "iteration": iteration, - logger.FieldAssistantContentLength: len(finalContent), + logger.FieldAssistantContentLength: len(state.finalContent), }) } break } + currentSignature := toolCallsSignature(response.ToolCalls) + if currentSignature != "" && currentSignature == state.lastToolCallSignature { + state.repeatedToolCallRounds++ + } else { + state.repeatedToolCallRounds = 0 + state.lastToolCallSignature = currentSignature + } + if state.repeatedToolCallRounds >= toolLoopRepeatSignatureThreshold { + logger.WarnCF("agent", "Repeated tool-call pattern 
detected, forcing finalization", map[string]interface{}{ + "iteration": iteration, + "session_key": sessionKey, + "repeat_round": state.repeatedToolCallRounds, + }) + messages = append(messages, providers.Message{ + Role: "user", + Content: "You are repeating the same tool calls. Stop calling tools and provide the best final answer now, including blockers and the minimum user input needed.", + }) + break + } + toolNames := make([]string, 0, len(response.ToolCalls)) for _, tc := range response.ToolCalls { toolNames = append(toolNames, tc.Name) @@ -2033,98 +2681,1401 @@ func (al *AgentLoop) runLLMToolLoop( "iteration": iteration, }) - assistantMsg := providers.Message{ - Role: "assistant", - Content: response.Content, - } - for _, tc := range response.ToolCalls { - argumentsJSON, _ := json.Marshal(tc.Arguments) - assistantMsg.ToolCalls = append(assistantMsg.ToolCalls, providers.ToolCall{ - ID: tc.ID, - Type: "function", - Function: &providers.FunctionCall{ - Name: tc.Name, - Arguments: string(argumentsJSON), - }, + budget := al.computeToolLoopBudget(state) + outcome := al.actToolCalls(ctx, response.Content, response.ToolCalls, &messages, sessionKey, iteration, budget, systemMode, progress) + state.lastToolResult = outcome.lastToolResult + if summary := summarizeToolActOutcome(outcome); summary != "" { + messages = append(messages, providers.Message{ + Role: "user", + Content: "Structured tool execution summary (for decision making): " + summary, }) } - messages = append(messages, assistantMsg) - al.sessions.AddMessageFull(sessionKey, assistantMsg) + if outcome.executedCalls > 0 && outcome.roundToolErrors == outcome.executedCalls { + state.consecutiveAllToolErrorRounds++ + } else { + state.consecutiveAllToolErrorRounds = 0 + } + if outcome.truncated && !systemMode { + logger.WarnCF("agent", "Tool execution truncated by budget", map[string]interface{}{ + "iteration": iteration, + "session_key": sessionKey, + "executed_calls": outcome.executedCalls, + "dropped_calls": 
outcome.droppedCalls, + }) + } + if state.consecutiveAllToolErrorRounds >= toolLoopAllErrorRoundsThreshold { + logger.WarnCF("agent", "Consecutive all-tool-error rounds detected, forcing finalization", map[string]interface{}{ + "iteration": iteration, + "session_key": sessionKey, + "error_rounds": state.consecutiveAllToolErrorRounds, + "tools_in_last_round": outcome.executedCalls, + }) + messages = append(messages, providers.Message{ + Role: "user", + Content: "All recent tool calls failed. Stop calling tools and provide a final answer with diagnosis, fallback suggestions, and what input/permission is missing.", + }) + break + } + if outcome.blockedLikely { + finalResp, ferr := al.finalizeToolLoop(ctx, append(messages, providers.Message{ + Role: "user", + Content: "Tool errors indicate hard blockers (permission/input/resource). Stop calling tools and provide diagnosis plus exact minimum user action needed.", + })) + if ferr == nil && finalResp != nil && strings.TrimSpace(finalResp.Content) != "" { + state.finalContent = finalResp.Content + } + break + } - for _, tc := range response.ToolCalls { + if al.shouldTriggerReflection(state, outcome) { + decision, reason, confidence := al.reflectToolLoopProgress(ctx, messages) + state.lastReflectDecision = decision + state.lastReflectConfidence = confidence + state.lastReflectIteration = iteration if !systemMode { - argsJSON, _ := json.Marshal(tc.Arguments) - argsPreview := truncate(string(argsJSON), 200) - logger.InfoCF("agent", fmt.Sprintf("Tool call: %s(%s)", tc.Name, argsPreview), - map[string]interface{}{ - "tool": tc.Name, - "iteration": iteration, - }) + logger.DebugCF("agent", "Tool-loop reflection", map[string]interface{}{ + "iteration": iteration, + "decision": decision, + "reason": reason, + "confidence": confidence, + }) } - - result, err := al.tools.Execute(ctx, tc.Name, tc.Arguments) - if err != nil { - result = fmt.Sprintf("Error: %v", err) - } - if progress != nil { - if err != nil { - progress.Publish(3, 5, 
"execution", fmt.Sprintf("Tool %s failed: %v", tc.Name, err)) - } else { - progress.Publish(3, 5, "execution", fmt.Sprintf("Tool %s completed.", tc.Name)) + switch decision { + case "done": + finalResp, ferr := al.finalizeToolLoop(ctx, append(messages, providers.Message{ + Role: "user", + Content: fmt.Sprintf("Reflection indicates completion (confidence %.2f). Provide the final user-facing answer now without tools. Reason: %s", confidence, reason), + })) + if ferr == nil && finalResp != nil && strings.TrimSpace(finalResp.Content) != "" { + state.finalContent = finalResp.Content + break } + messages = append(messages, providers.Message{ + Role: "user", + Content: fmt.Sprintf("Reflection indicates completion. Provide final answer now without tools. Reason: %s", reason), + }) + break + case "blocked": + finalResp, ferr := al.finalizeToolLoop(ctx, append(messages, providers.Message{ + Role: "user", + Content: fmt.Sprintf("Reflection indicates blocked progress (confidence %.2f). Stop calling tools and provide diagnosis, blockers, and minimum user input needed. Reason: %s", confidence, reason), + })) + if ferr == nil && finalResp != nil && strings.TrimSpace(finalResp.Content) != "" { + state.finalContent = finalResp.Content + break + } + messages = append(messages, providers.Message{ + Role: "user", + Content: fmt.Sprintf("Blocked progress detected. Stop calling tools and provide diagnosis plus minimum needed user action. Reason: %s", reason), + }) + break + default: + messages = append(messages, providers.Message{ + Role: "user", + Content: fmt.Sprintf("Continue execution with minimal next-step tools only. Avoid repetition. 
Reflection reason: %s", reason), + }) } - lastToolResult = result - - toolResultMsg := providers.Message{ - Role: "tool", - Content: result, - ToolCallID: tc.ID, + if state.finalContent != "" || decision == "done" || decision == "blocked" { + break } - messages = append(messages, toolResultMsg) - al.sessions.AddMessageFull(sessionKey, toolResultMsg) } } // When max iterations are reached without a direct answer, ask once more without tools. // This avoids returning placeholder text to end users. - if finalContent == "" && len(messages) > 0 { - if !systemMode && iteration >= al.maxIterations { + if state.finalContent == "" && len(messages) > 0 { + if !systemMode && state.iteration >= al.maxIterations { logger.WarnCF("agent", "Max tool iterations reached without final answer; forcing finalization pass", map[string]interface{}{ - "iteration": iteration, + "iteration": state.iteration, "max": al.maxIterations, "session_key": sessionKey, }) } - finalizeMessages := append([]providers.Message{}, messages...) - finalizeMessages = append(finalizeMessages, providers.Message{ - Role: "user", - Content: "Now provide your final response to the user based on the completed tool results. 
// toolLoopState is the bookkeeping runLLMToolLoop carries across iterations.
//
// iteration is incremented at the top of every pass. finalContent holds the
// answer once the model replies without tool calls or a finalization pass
// succeeds. lastToolResult is the fallback answer when finalContent stays
// empty. lastToolCallSignature and repeatedToolCallRounds detect the model
// repeating the same tool calls; consecutiveAllToolErrorRounds counts rounds
// in which every executed call failed. The lastReflect* fields record the
// most recent reflection verdict and the iteration it was taken at.
type toolLoopState struct {
	iteration                     int
	finalContent                  string
	lastToolResult                string
	lastToolCallSignature         string
	repeatedToolCallRounds        int
	consecutiveAllToolErrorRounds int
	lastReflectDecision           string
	lastReflectConfidence         float64
	lastReflectIteration          int
}

// toolActOutcome summarizes one round of tool execution as produced by
// actToolCalls: how many calls ran, failed, returned empty output or were
// dropped, plus one record per executed call.
type toolActOutcome struct {
	roundToolErrors int
	lastToolResult  string
	executedCalls   int
	droppedCalls    int
	truncated       bool
	emptyResults    int
	retryableErrors int
	hardErrors      int
	blockedLikely   bool
	records         []toolExecutionRecord
}

// loopReflectResponse mirrors the JSON schema requested from the model by
// reflectToolLoopProgress ("done", "continue" or "blocked" plus a reason and
// confidence).
type loopReflectResponse struct {
	Decision   string  `json:"decision"`
	Reason     string  `json:"reason"`
	Confidence float64 `json:"confidence"`
}

// finalizeQualityResponse is the JSON payload decoded by
// assessFinalizeDraftQuality when it asks the model to score a draft.
type finalizeQualityResponse struct {
	Score float64 `json:"score"`
}

// localReflectSignal is the result of the local (non-LLM) reflection
// heuristic. When uncertain is true, reflectToolLoopProgress falls back to
// asking the model.
type localReflectSignal struct {
	decision   string
	reason     string
	confidence float64
	uncertain  bool
}

// selfRepairDecision is a JSON-decodable verdict on whether a response needs
// a repair pass.
// NOTE(review): the producer/consumer is outside this chunk — presumably
// runSelfRepairIfNeeded; confirm.
type selfRepairDecision struct {
	NeedsRepair  bool    `json:"needs_repair"`
	Reason       string  `json:"reason"`
	RepairPrompt string  `json:"repair_prompt"`
	Confidence   float64 `json:"confidence"`
}

// selfRepairMemory tracks prompts and outputs already seen during self-repair.
// NOTE(review): usage is outside this chunk; presumably it prevents repeating
// the same repair attempt — confirm.
type selfRepairMemory struct {
	promptsUsed   map[string]struct{}
	outputsSeen   map[string]struct{}
	failureReason []string
}

// toolExecutionRecord is the structured per-call result recorded by
// actToolCalls. Status is "ok", "error" or "empty"; the error fields are only
// populated for failed calls.
type toolExecutionRecord struct {
	Tool       string `json:"tool"`
	Status     string `json:"status"`
	ErrorType  string `json:"error_type,omitempty"`
	Retryable  bool   `json:"retryable,omitempty"`
	ErrMessage string `json:"error,omitempty"`
}

// toolLoopBudget bounds a single tool-execution round. actToolCalls replaces
// non-positive values with the package defaults.
type toolLoopBudget struct {
	maxCallsPerIteration int
	singleCallTimeout    time.Duration
	maxActDuration       time.Duration
}

// toolCallExecResult pairs a tool call with its textual result and error,
// along with the position the executor assigned to it.
type toolCallExecResult struct {
	index  int
	call   providers.ToolCall
	result string
	err    error
}
+ errLog := "LLM call failed" + if systemMode { + errLog = "LLM call failed in system message" + } + logger.ErrorCF("agent", errLog, map[string]interface{}{ + "iteration": iteration, + logger.FieldError: err.Error(), + "elapsed": llmElapsed.String(), + }) + return nil, llmElapsed, err + } + return response, llmElapsed, nil +} + +func (al *AgentLoop) actToolCalls( + ctx context.Context, + assistantContent string, + toolCalls []providers.ToolCall, + messages *[]providers.Message, + sessionKey string, + iteration int, + budget toolLoopBudget, + systemMode bool, + progress *stageReporter, +) toolActOutcome { + outcome := toolActOutcome{} + if len(toolCalls) == 0 { + return outcome + } + execCalls := toolCalls + maxCalls := budget.maxCallsPerIteration + if maxCalls <= 0 { + maxCalls = toolLoopMaxCallsPerIteration + } + if len(execCalls) > maxCalls { + outcome.truncated = true + outcome.droppedCalls = len(execCalls) - maxCalls + execCalls = execCalls[:maxCalls] + } + + assistantMsg := providers.Message{ + Role: "assistant", + Content: assistantContent, + } + for _, tc := range execCalls { + argumentsJSON, _ := json.Marshal(tc.Arguments) + assistantMsg.ToolCalls = append(assistantMsg.ToolCalls, providers.ToolCall{ + ID: tc.ID, + Type: "function", + Function: &providers.FunctionCall{ + Name: tc.Name, + Arguments: string(argumentsJSON), + }, + }) + } + *messages = append(*messages, assistantMsg) + al.sessions.AddMessageFull(sessionKey, assistantMsg) + + start := time.Now() + maxActDuration := budget.maxActDuration + if maxActDuration <= 0 { + maxActDuration = toolLoopMaxActDuration + } + singleTimeout := budget.singleCallTimeout + if singleTimeout <= 0 { + singleTimeout = toolLoopSingleCallTimeout + } + roundCtx, cancelRound := context.WithTimeout(ctx, maxActDuration) + defer cancelRound() + + parallel := al.shouldRunToolCallsInParallel(execCalls) + results := al.executeToolCalls(roundCtx, execCalls, iteration, singleTimeout, parallel, systemMode, progress) + if 
time.Since(start) >= maxActDuration && len(results) < len(execCalls) { + outcome.truncated = true + outcome.droppedCalls += len(execCalls) - len(results) + } + + for i, execRes := range results { + tc := execRes.call + result := execRes.result + err := execRes.err + record := toolExecutionRecord{ + Tool: tc.Name, + Status: "ok", + } + if err != nil { + result = fmt.Sprintf("Error: %v", err) + outcome.roundToolErrors++ + record.Status = "error" + record.ErrorType, record.Retryable, outcome.blockedLikely = classifyToolExecutionError(err, outcome.blockedLikely) + record.ErrMessage = truncate(err.Error(), 240) + if record.Retryable { + outcome.retryableErrors++ + } else { + outcome.hardErrors++ + } + } else if strings.TrimSpace(result) == "" { + outcome.emptyResults++ + record.Status = "empty" + } + if progress != nil { + if err != nil { + progress.Publish(3, 5, "execution", fmt.Sprintf("Tool %s failed: %v", tc.Name, err)) + } else { + progress.Publish(3, 5, "execution", fmt.Sprintf("Tool %s completed.", tc.Name)) + } + } + outcome.lastToolResult = result + + toolResultMsg := providers.Message{ + Role: "tool", + Content: result, + ToolCallID: tc.ID, + } + *messages = append(*messages, toolResultMsg) + if shouldPersistToolResultRecord(record, i, len(results)) { + al.sessions.AddMessageFull(sessionKey, toolResultMsg) + } + outcome.executedCalls++ + outcome.records = append(outcome.records, record) + } + return outcome +} + +func (al *AgentLoop) executeToolCalls( + ctx context.Context, + execCalls []providers.ToolCall, + iteration int, + singleTimeout time.Duration, + parallel bool, + systemMode bool, + progress *stageReporter, +) []toolCallExecResult { + if parallel { + return al.executeToolCallsBatchedParallel(ctx, execCalls, iteration, singleTimeout, systemMode, progress) + } + return al.executeToolCallsSerial(ctx, execCalls, iteration, singleTimeout, systemMode, progress) +} + +func (al *AgentLoop) executeToolCallsBatchedParallel( + ctx context.Context, + execCalls 
[]providers.ToolCall, + iteration int, + singleTimeout time.Duration, + systemMode bool, + progress *stageReporter, +) []toolCallExecResult { + batches := al.buildParallelBatches(execCalls) + results := make([]toolCallExecResult, 0, len(execCalls)) + for _, batch := range batches { + select { + case <-ctx.Done(): + return results + default: + } + if len(batch) <= 1 { + if len(batch) == 1 { + results = append(results, al.executeSingleToolCall(ctx, len(results), batch[0], iteration, singleTimeout, systemMode, progress)) + } + continue + } + out := al.executeToolCallsParallel(ctx, batch, iteration, singleTimeout, systemMode, progress) + results = append(results, out...) + } + return results +} + +func (al *AgentLoop) executeToolCallsSerial( + ctx context.Context, + execCalls []providers.ToolCall, + iteration int, + singleTimeout time.Duration, + systemMode bool, + progress *stageReporter, +) []toolCallExecResult { + results := make([]toolCallExecResult, 0, len(execCalls)) + for i, tc := range execCalls { + select { + case <-ctx.Done(): + return results + default: + } + res := al.executeSingleToolCall(ctx, i, tc, iteration, singleTimeout, systemMode, progress) + results = append(results, res) + } + return results +} + +func (al *AgentLoop) executeToolCallsParallel( + ctx context.Context, + execCalls []providers.ToolCall, + iteration int, + singleTimeout time.Duration, + systemMode bool, + progress *stageReporter, +) []toolCallExecResult { + results := make([]toolCallExecResult, len(execCalls)) + limit := al.maxToolParallelCalls() + if limit <= 1 { + return al.executeToolCallsSerial(ctx, execCalls, iteration, singleTimeout, systemMode, progress) + } + if len(execCalls) < limit { + limit = len(execCalls) + } + sem := make(chan struct{}, limit) + var wg sync.WaitGroup + for i, tc := range execCalls { + select { + case <-ctx.Done(): + goto wait + default: + } + sem <- struct{}{} + wg.Add(1) + go func(i int, tc providers.ToolCall) { + defer wg.Done() + defer func() { <-sem 
}() + results[i] = al.executeSingleToolCall(ctx, i, tc, iteration, singleTimeout, systemMode, progress) + }(i, tc) + } +wait: + wg.Wait() + out := make([]toolCallExecResult, 0, len(execCalls)) + for i := range results { + if strings.TrimSpace(results[i].call.Name) == "" { + continue + } + out = append(out, results[i]) + } + return out +} + +func (al *AgentLoop) buildParallelBatches(execCalls []providers.ToolCall) [][]providers.ToolCall { + if len(execCalls) == 0 { + return nil + } + batches := make([][]providers.ToolCall, 0, len(execCalls)) + current := make([]providers.ToolCall, 0, len(execCalls)) + used := map[string]struct{}{} + + flush := func() { + if len(current) == 0 { + return + } + batch := append([]providers.ToolCall(nil), current...) + batches = append(batches, batch) + current = current[:0] + used = map[string]struct{}{} + } + + for _, tc := range execCalls { + keys := al.toolResourceKeys(tc.Name, tc.Arguments) + if len(current) > 0 && hasResourceKeyConflict(used, keys) { + flush() + } + current = append(current, tc) + for _, k := range keys { + used[k] = struct{}{} + } + } + flush() + return batches +} + +func hasResourceKeyConflict(used map[string]struct{}, keys []string) bool { + if len(keys) == 0 || len(used) == 0 { + return false + } + for _, k := range keys { + if _, ok := used[k]; ok { + return true + } + } + return false +} + +func (al *AgentLoop) toolResourceKeys(name string, args map[string]interface{}) []string { + raw := strings.TrimSpace(name) + lower := strings.ToLower(raw) + if raw == "" || al == nil || al.tools == nil { + return nil + } + tool, ok := al.tools.Get(raw) + if !ok && lower != raw { + tool, ok = al.tools.Get(lower) + } + if !ok || tool == nil { + return nil + } + rs, ok := tool.(tools.ResourceScopedTool) + if !ok { + return nil + } + return normalizeResourceKeys(rs.ResourceKeys(args)) +} + +func normalizeResourceKeys(keys []string) []string { + if len(keys) == 0 { + return nil + } + out := make([]string, 0, len(keys)) + seen 
:= make(map[string]struct{}, len(keys)) + for _, k := range keys { + n := strings.ToLower(strings.TrimSpace(k)) + if n == "" { + continue + } + if _, ok := seen[n]; ok { + continue + } + seen[n] = struct{}{} + out = append(out, n) + } + return out +} + +func (al *AgentLoop) executeSingleToolCall( + ctx context.Context, + index int, + tc providers.ToolCall, + iteration int, + singleTimeout time.Duration, + systemMode bool, + progress *stageReporter, +) toolCallExecResult { + if !systemMode { + argsJSON, _ := json.Marshal(tc.Arguments) + argsPreview := truncate(string(argsJSON), 200) + logger.InfoCF("agent", fmt.Sprintf("Tool call: %s(%s)", tc.Name, argsPreview), map[string]interface{}{ + "tool": tc.Name, + "iteration": iteration, + }) + } + + toolCtx, cancelTool := context.WithTimeout(ctx, singleTimeout) + defer cancelTool() + result, err := al.tools.Execute(toolCtx, tc.Name, tc.Arguments) + if err != nil { + result = fmt.Sprintf("Error: %v", err) + } + if progress != nil { + if err != nil { + progress.Publish(3, 5, "execution", fmt.Sprintf("Tool %s failed: %v", tc.Name, err)) + } else { + progress.Publish(3, 5, "execution", fmt.Sprintf("Tool %s completed.", tc.Name)) + } + } + return toolCallExecResult{ + index: index, + call: tc, + result: result, + err: err, + } +} + +func (al *AgentLoop) shouldRunToolCallsInParallel(calls []providers.ToolCall) bool { + if len(calls) <= 1 { + return false + } + for _, c := range calls { + if !al.isParallelSafeTool(c.Name) { + return false + } + } + return true +} + +func (al *AgentLoop) isParallelSafeTool(name string) bool { + raw := strings.TrimSpace(name) + tool := strings.ToLower(raw) + if raw == "" { + return false + } + if al != nil && al.tools != nil { + if t, ok := al.tools.Get(raw); ok { + if ps, ok := t.(tools.ParallelSafeTool); ok { + return ps.ParallelSafe() + } + } + if raw != tool { + if t, ok := al.tools.Get(tool); ok { + if ps, ok := t.(tools.ParallelSafeTool); ok { + return ps.ParallelSafe() + } + } + } + } + 
allowed := al.effectiveParallelSafeTools() + _, ok := allowed[tool] + return ok +} + +func (al *AgentLoop) effectiveParallelSafeTools() map[string]struct{} { + if al == nil || len(al.parallelSafeTools) == 0 { + m := make(map[string]struct{}, len(defaultParallelSafeToolNames)) + for _, name := range defaultParallelSafeToolNames { + m[strings.ToLower(strings.TrimSpace(name))] = struct{}{} + } + return m + } + return al.parallelSafeTools +} + +func (al *AgentLoop) maxToolParallelCalls() int { + if al == nil || al.maxParallelCalls <= 0 { + return toolLoopMaxParallelCalls + } + return al.maxParallelCalls +} + +func (al *AgentLoop) finalizeToolLoop(ctx context.Context, messages []providers.Message) (*providers.LLMResponse, error) { + finalizeMessages := append([]providers.Message{}, messages...) + finalizeMessages = append(finalizeMessages, providers.Message{ + Role: "user", + Content: "Now provide your final response to the user based on the completed tool results. Do not call any tools.", + }) + finalizeMessages = sanitizeMessagesForToolCalling(finalizeMessages) + + llmCtx, cancelLLM := context.WithTimeout(ctx, al.llmCallTimeout) + draftResp, err := al.callLLMWithModelFallback(llmCtx, finalizeMessages, nil, map[string]interface{}{ + "max_tokens": 1024, + "temperature": 0.3, + }) + cancelLLM() + if err != nil { + return nil, err + } + if draftResp == nil || strings.TrimSpace(draftResp.Content) == "" { + return draftResp, nil + } + // Gate polish by draft quality to avoid unnecessary extra LLM pass. 
+ if !shouldRunFinalizePolish(draftResp.Content) { + return draftResp, nil + } + quality := al.assessFinalizeDraftQuality(ctx, draftResp.Content) + if quality >= finalizeQualityThreshold { + return draftResp, nil + } + + polished, perr := al.polishFinalResponse(ctx, draftResp.Content) + if perr != nil || strings.TrimSpace(polished) == "" { + return draftResp, nil + } + return &providers.LLMResponse{ + Content: polished, + Usage: draftResp.Usage, + }, nil +} + +func shouldRunFinalizePolish(draft string) bool { + text := strings.TrimSpace(draft) + if len(text) < finalizeDraftMinCharsForPolish { + return false + } + return containsAnySubstring(strings.ToLower(text), " - ", "1.", "2.", "next", "建议", "步骤", "\n") +} + +func (al *AgentLoop) assessFinalizeDraftQuality(ctx context.Context, draft string) float64 { + if strings.TrimSpace(draft) == "" { + return 0 + } + heuristic := localFinalizeDraftQualityScore(draft) + if heuristic >= finalizeHeuristicHighThreshold || heuristic <= finalizeHeuristicLowThreshold || al == nil { + return heuristic + } + + // Only call LLM when heuristic is uncertain; keep this call lightweight. + msgs := []providers.Message{ + { + Role: "user", + Content: `Evaluate draft answer quality for user delivery. +Return JSON only: +{"score":0.0} +Scoring: +- 1.0 = clear, concise, actionable, no repetition. 
+- 0.0 = unclear or noisy.`, + }, + { + Role: "user", + Content: draft, + }, + } + timeout := al.llmCallTimeout / 4 + if timeout < 2*time.Second { + timeout = 2 * time.Second + } + qctx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + resp, err := al.callLLMWithModelFallback(qctx, msgs, nil, map[string]interface{}{ + "max_tokens": 28, + "temperature": 0.0, + }) + if err != nil || resp == nil { + return heuristic + } + raw := extractJSONObject(resp.Content) + if raw == "" { + return heuristic + } + var parsed finalizeQualityResponse + if err := json.Unmarshal([]byte(raw), &parsed); err != nil { + return heuristic + } + if parsed.Score < 0 { + return heuristic + } + if parsed.Score > 1 { + parsed.Score = 1 + } + // Blend heuristic and LLM score for stability while keeping calls cheap. + return clamp01(0.6*heuristic + 0.4*parsed.Score) +} + +func localFinalizeDraftQualityScore(draft string) float64 { + text := strings.TrimSpace(draft) + if text == "" { + return 0 + } + + score := 0.0 + length := len(text) + switch { + case length >= 420: + score += 0.36 + case length >= 260: + score += 0.30 + case length >= 140: + score += 0.22 + case length >= 80: + score += 0.14 + default: + score += 0.06 + } + + lower := strings.ToLower(text) + if containsAnySubstring(lower, "\n-", "\n1.", "\n2.", "next", "steps", "建议", "步骤", "下一步") { + score += 0.22 + } + if containsAnySubstring(lower, "because", "therefore", "原因", "结论", "result", "建议") { + score += 0.14 + } + if containsAnySubstring(lower, "error", "failed", "unknown", "todo", "tbd") { + score -= 0.08 + } + + lines := strings.Split(text, "\n") + if hasExcessiveDuplicateLines(lines) { + score -= 0.18 + } + return clamp01(score) +} + +func hasExcessiveDuplicateLines(lines []string) bool { + if len(lines) < 3 { + return false + } + seen := map[string]int{} + dup := 0 + total := 0 + for _, line := range lines { + n := strings.TrimSpace(strings.ToLower(line)) + if n == "" { + continue + } + total++ + seen[n]++ + if 
seen[n] > 1 { + dup++ + } + } + if total == 0 { + return false + } + return float64(dup)/float64(total) > 0.25 +} + +func clamp01(v float64) float64 { + if v < 0 { + return 0 + } + if v > 1 { + return 1 + } + return v +} + +func (al *AgentLoop) polishFinalResponse(ctx context.Context, draft string) (string, error) { + if al == nil || strings.TrimSpace(draft) == "" { + return draft, nil + } + // Phase-2 polish: keep facts, remove repetition, and present concise actionable output. + polishMessages := []providers.Message{ + { + Role: "system", + Content: al.withBootstrapPolicy(`Rewrite the draft answer for end users. +Rules: +- Keep factual meaning unchanged. +- Keep concise and actionable. +- Remove internal reasoning or repetitive text. +- Plain text only.`), + }, + { + Role: "user", + Content: draft, + }, + } + timeout := al.llmCallTimeout / 2 + if timeout < 3*time.Second { + timeout = 3 * time.Second + } + llmCtx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + resp, err := al.callLLMWithModelFallback(llmCtx, polishMessages, nil, map[string]interface{}{ + "max_tokens": 520, + "temperature": 0.2, + }) + if err != nil || resp == nil { + return "", err + } + return strings.TrimSpace(resp.Content), nil +} + +func (al *AgentLoop) reflectToolLoopProgress(ctx context.Context, messages []providers.Message) (decision string, reason string, confidence float64) { + if al == nil { + return "continue", "agent unavailable", 0 + } + local := inferLocalReflectionSignal(messages) + if !local.uncertain { + return local.decision, local.reason, local.confidence + } + + reflectMessages := append([]providers.Message{}, messages...) + reflectMessages = append(reflectMessages, providers.Message{ + Role: "user", + Content: `Classify current execution progress using JSON only. +Schema: +{"decision":"done|continue|blocked","reason":"short reason","confidence":0.0} +Rules: +- done: objective appears completed from current tool outputs. 
+- blocked: cannot make meaningful progress without new user input/permission/resource. +- continue: still actionable with additional non-repetitive tool calls. +- Keep reason <= 18 words.`, + }) + reflectMessages = sanitizeMessagesForToolCalling(reflectMessages) + + rctx, cancel := context.WithTimeout(ctx, toolLoopReflectTimeout) + defer cancel() + resp, err := al.callLLMWithModelFallback(rctx, reflectMessages, nil, map[string]interface{}{ + "max_tokens": 120, + "temperature": 0.0, + }) + if err != nil || resp == nil { + return local.decision, local.reason, local.confidence + } + raw := extractJSONObject(resp.Content) + if raw == "" { + return local.decision, local.reason, local.confidence + } + var parsed loopReflectResponse + if err := json.Unmarshal([]byte(raw), &parsed); err != nil { + return local.decision, local.reason, local.confidence + } + decision = normalizeReflectDecision(parsed.Decision) + reason = strings.TrimSpace(parsed.Reason) + if reason == "" { + reason = "insufficient signal" + } + confidence = parsed.Confidence + if confidence < 0 { + confidence = 0 + } + if confidence > 1 { + confidence = 1 + } + if decision == "done" && confidence < 0.55 { + return "continue", reason, confidence + } + if decision == "blocked" && confidence < 0.45 { + return "continue", reason, confidence + } + return decision, reason, confidence +} + +func inferLocalReflectionSignal(messages []providers.Message) localReflectSignal { + lastToolCount := 0 + errorCount := 0 + emptyCount := 0 + latestToolText := "" + + // Scan recent tool outputs only (windowed) for cheap deterministic signal. 
+ for i := len(messages) - 1; i >= 0 && lastToolCount < 6; i-- { + if strings.TrimSpace(messages[i].Role) != "tool" { + continue + } + lastToolCount++ + text := strings.TrimSpace(messages[i].Content) + if latestToolText == "" { + latestToolText = strings.ToLower(text) + } + if text == "" { + emptyCount++ + continue + } + lower := strings.ToLower(text) + if strings.HasPrefix(lower, "error:") || containsAnySubstring(lower, "failed", "permission denied", "forbidden", "unauthorized", "not allowed") { + errorCount++ + } + } + if lastToolCount == 0 { + return localReflectSignal{ + decision: "continue", + reason: "insufficient local signal", + confidence: 0.40, + uncertain: true, + } + } + if errorCount >= 2 && errorCount == lastToolCount { + return localReflectSignal{ + decision: "blocked", + reason: "recent tool outputs are all errors", + confidence: 0.90, + uncertain: false, + } + } + if errorCount > 0 && containsAnySubstring(latestToolText, "permission denied", "forbidden", "unauthorized", "not allowed") { + return localReflectSignal{ + decision: "blocked", + reason: "permission failure detected", + confidence: 0.86, + uncertain: false, + } + } + if errorCount == 0 && emptyCount == 0 && containsAnySubstring(latestToolText, "completed", "success", "done", "ok") { + return localReflectSignal{ + decision: "done", + reason: "successful tool output indicates completion", + confidence: 0.80, + uncertain: false, + } + } + return localReflectSignal{ + decision: "continue", + reason: "mixed signals require model judgment", + confidence: 0.52, + uncertain: true, + } +} + +func (al *AgentLoop) runSelfRepairIfNeeded( + ctx context.Context, + sessionKey string, + userPrompt string, + baseMessages []providers.Message, + finalContent string, + progress *stageReporter, +) (string, int) { + current := strings.TrimSpace(finalContent) + if current == "" { + return finalContent, 0 + } + mem := selfRepairMemory{ + promptsUsed: make(map[string]struct{}), + outputsSeen: map[string]struct{}{ 
+ repairOutputSignature(current): {}, + }, + } + repairPasses := 0 + for repairPasses < maxSelfRepairPasses { + needs, repairPrompt, confidence := al.shouldRunSelfRepair(ctx, userPrompt, current, mem) + if !needs || strings.TrimSpace(repairPrompt) == "" { + break + } + normalizedPrompt := normalizeRepairPrompt(repairPrompt) + if _, seen := mem.promptsUsed[normalizedPrompt]; seen { + mem.failureReason = append(mem.failureReason, "repeated prompt") + break + } + mem.promptsUsed[normalizedPrompt] = struct{}{} + repairPasses++ + if progress != nil { + progress.Publish(4, 5, "self-repair", fmt.Sprintf("Running self-repair pass %d (confidence %.2f).", repairPasses, confidence)) + } + repairMessages := append([]providers.Message{}, baseMessages...) + repairMessages = append(repairMessages, providers.Message{ + Role: "user", + Content: fmt.Sprintf("Self-repair pass: %s\nCurrent draft response:\n%s", + repairPrompt, + truncateString(current, 1200), + ), + }) + repaired, _, err := al.runLLMToolLoop(ctx, repairMessages, sessionKey, false, nil) + repaired = strings.TrimSpace(repaired) + if err != nil || repaired == "" { + mem.failureReason = append(mem.failureReason, "empty or failed repair run") + break + } + sig := repairOutputSignature(repaired) + if _, seen := mem.outputsSeen[sig]; seen { + mem.failureReason = append(mem.failureReason, "repeated output") + break + } + mem.outputsSeen[sig] = struct{}{} + current = repaired + } + return current, repairPasses +} + +func (al *AgentLoop) shouldRunSelfRepair(ctx context.Context, userPrompt string, finalContent string, mem selfRepairMemory) (needs bool, repairPrompt string, confidence float64) { + text := strings.TrimSpace(finalContent) + if text == "" { + return false, "", 0 + } + if needs, prompt := shouldForceSelfRepairHeuristic(strings.TrimSpace(userPrompt), text); needs { + if promptSeen(mem, prompt) { + return false, "", 0 + } + return true, prompt, 0.86 + } + + if al == nil { + return false, "", 0 + } + llmTimeout := 
al.llmCallTimeout / 4 + if llmTimeout < 2*time.Second { + llmTimeout = 2 * time.Second + } + rctx, cancel := context.WithTimeout(ctx, llmTimeout) + defer cancel() + resp, err := al.callLLMWithModelFallback(rctx, []providers.Message{ + { + Role: "user", + Content: `Judge whether the draft fully satisfies the user task. +Return JSON only: +{"needs_repair":true|false,"reason":"short reason","repair_prompt":"short actionable prompt","confidence":0.0}`, + }, + {Role: "user", Content: fmt.Sprintf("User task:\n%s\n\nDraft response:\n%s", userPrompt, truncateString(text, 1200))}, + }, nil, map[string]interface{}{ + "max_tokens": 96, + "temperature": 0.0, + }) + if err != nil || resp == nil { + return false, "", 0 + } + raw := extractJSONObject(resp.Content) + if raw == "" { + return false, "", 0 + } + var parsed selfRepairDecision + if err := json.Unmarshal([]byte(raw), &parsed); err != nil { + return false, "", 0 + } + if parsed.Confidence < 0 { + parsed.Confidence = 0 + } + if parsed.Confidence > 1 { + parsed.Confidence = 1 + } + if !parsed.NeedsRepair || parsed.Confidence < 0.62 { + return false, "", parsed.Confidence + } + prompt := strings.TrimSpace(parsed.RepairPrompt) + if prompt == "" { + prompt = strings.TrimSpace(parsed.Reason) + } + if prompt == "" { + prompt = "Address missing requirements and provide a complete, actionable final answer." 
+ } + if promptSeen(mem, prompt) { + return false, "", parsed.Confidence + } + return true, prompt, parsed.Confidence +} + +func normalizeRepairPrompt(prompt string) string { + return strings.ToLower(strings.TrimSpace(prompt)) +} + +func promptSeen(mem selfRepairMemory, prompt string) bool { + if len(mem.promptsUsed) == 0 { + return false + } + _, ok := mem.promptsUsed[normalizeRepairPrompt(prompt)] + return ok +} + +func repairOutputSignature(content string) string { + text := strings.ToLower(strings.TrimSpace(content)) + if len(text) > 480 { + text = text[:480] + } + return text +} + +func shouldForceSelfRepairHeuristic(userPrompt string, finalContent string) (bool, string) { + prompt := strings.ToLower(strings.TrimSpace(userPrompt)) + resp := strings.ToLower(strings.TrimSpace(finalContent)) + if resp == "" { + return true, "Response is empty. Provide complete final answer." + } + if containsAnySubstring(resp, "i don't know", "cannot complete", "无法完成", "不知道", "todo", "tbd") { + return true, "Replace uncertainty with concrete diagnosis and next actionable steps." + } + if containsAnySubstring(prompt, "steps", "步骤", "plan", "方案", "how to", "如何") && + !containsAnySubstring(resp, "1.", "2.", "step", "步骤", "next", "下一步") { + return true, "Provide structured step-by-step answer aligned with user task." 
+ } + return false, "" +} + +func normalizeReflectDecision(value string) string { + switch strings.ToLower(strings.TrimSpace(value)) { + case "done": + return "done" + case "blocked": + return "blocked" + default: + return "continue" + } +} + +func classifyToolExecutionError(err error, blockedAlready bool) (errType string, retryable bool, blockedLikely bool) { + blockedLikely = blockedAlready + if err == nil { + return "none", false, blockedLikely + } + if errors.Is(err, context.DeadlineExceeded) { + return "timeout", true, blockedLikely + } + msg := strings.ToLower(strings.TrimSpace(err.Error())) + switch { + case containsAnySubstring(msg, "forbidden", "permission", "denied", "unauthorized", "not allowed"): + return "permission", false, true + case containsAnySubstring(msg, "missing", "required", "invalid argument", "bad argument", "invalid parameter", "parse"): + return "invalid_input", false, true + case containsAnySubstring(msg, "not found", "no such file", "does not exist"): + return "not_found", false, blockedLikely + case containsAnySubstring(msg, "timeout", "temporary", "temporarily", "unavailable", "connection reset", "connection refused", "rate limit", "429", "502", "503", "504"): + return "transient", true, blockedLikely + default: + return "unknown", false, blockedLikely + } +} + +func summarizeToolActOutcome(outcome toolActOutcome) string { + if outcome.executedCalls == 0 { + return "" + } + records, truncatedCount := compactToolExecutionRecords(outcome.records, toolSummaryMaxRecords) + errorTypeCount := map[string]int{} + for _, r := range outcome.records { + if r.Status != "error" { + continue + } + key := strings.TrimSpace(r.ErrorType) + if key == "" { + key = "unknown" + } + errorTypeCount[key]++ + } + summary := map[string]interface{}{ + "executed_calls": outcome.executedCalls, + "dropped_calls": outcome.droppedCalls, + "errors": outcome.roundToolErrors, + "retryable_errors": outcome.retryableErrors, + "hard_errors": outcome.hardErrors, + 
"empty_results": outcome.emptyResults, + "blocked_likely": outcome.blockedLikely, + "truncated_by_budget": outcome.truncated, + "records": records, + "records_truncated": truncatedCount, + "error_type_count": errorTypeCount, + } + data, err := json.Marshal(summary) + if err != nil { + return "" + } + return string(data) +} + +func compactToolExecutionRecords(records []toolExecutionRecord, max int) ([]toolExecutionRecord, int) { + if len(records) == 0 || max <= 0 || len(records) <= max { + return records, 0 + } + selected := make([]toolExecutionRecord, 0, max) + used := make(map[int]struct{}, max) + + // Keep all errors first. + for i, r := range records { + if len(selected) >= max { + break + } + if r.Status == "error" { + selected = append(selected, r) + used[i] = struct{}{} + } + } + // Keep one early non-error exemplar. + for i, r := range records { + if len(selected) >= max { + break + } + if _, ok := used[i]; ok { + continue + } + if r.Status != "error" { + selected = append(selected, r) + used[i] = struct{}{} + break + } + } + // Keep one latest non-error exemplar from tail. + for i := len(records) - 1; i >= 0; i-- { + if len(selected) >= max { + break + } + if _, ok := used[i]; ok { + continue + } + r := records[i] + if r.Status != "error" { + selected = append(selected, r) + used[i] = struct{}{} + break + } + } + // Fill remaining slots by original order. + for i, r := range records { + if len(selected) >= max { + break + } + if _, ok := used[i]; ok { + continue + } + selected = append(selected, r) + used[i] = struct{}{} + } + + truncated := len(records) - len(selected) + if truncated < 0 { + truncated = 0 + } + return selected, truncated +} + +func (al *AgentLoop) computeToolLoopBudget(state toolLoopState) toolLoopBudget { + b := toolLoopBudget{ + maxCallsPerIteration: toolLoopMaxCallsPerIteration, + singleCallTimeout: toolLoopSingleCallTimeout, + maxActDuration: toolLoopMaxActDuration, + } + // Early rounds can be slightly wider if no degradation signals. 
+ if state.iteration <= 1 && state.consecutiveAllToolErrorRounds == 0 && state.repeatedToolCallRounds == 0 { + b.maxCallsPerIteration = toolLoopMaxCallsPerIteration + 1 + b.maxActDuration = toolLoopMaxActDuration + 10*time.Second + } + // If failures accumulate, tighten budget to force quicker convergence. + if state.consecutiveAllToolErrorRounds > 0 || state.repeatedToolCallRounds > 0 { + b.maxCallsPerIteration = toolLoopMaxCallsPerIteration - 2 + b.singleCallTimeout = toolLoopSingleCallTimeout - 6*time.Second + b.maxActDuration = toolLoopMaxActDuration - 15*time.Second + } + // Couple reflection confidence to next-round budget when reflection is recent. + if state.lastReflectIteration > 0 && state.iteration-state.lastReflectIteration <= 1 { + switch state.lastReflectDecision { + case "blocked": + b.maxCallsPerIteration = toolLoopMinCallsPerIteration + b.singleCallTimeout = toolLoopMinSingleCallTimeout + b.maxActDuration = toolLoopMinActDuration + case "continue": + if state.lastReflectConfidence < 0.6 { + b.maxCallsPerIteration -= 1 + b.singleCallTimeout -= 3 * time.Second + b.maxActDuration -= 8 * time.Second + } else if state.lastReflectConfidence >= 0.85 && + state.consecutiveAllToolErrorRounds == 0 && + state.repeatedToolCallRounds == 0 { + b.maxCallsPerIteration += 1 + b.singleCallTimeout += 2 * time.Second + b.maxActDuration += 6 * time.Second + } + } + } + // Near max iterations, force conservative execution. 
+ if al != nil && state.iteration >= al.maxIterations-1 { + b.maxCallsPerIteration = toolLoopMinCallsPerIteration + b.singleCallTimeout = toolLoopMinSingleCallTimeout + b.maxActDuration = toolLoopMinActDuration + } + if b.maxCallsPerIteration > toolLoopMaxCallsPerIteration+2 { + b.maxCallsPerIteration = toolLoopMaxCallsPerIteration + 2 + } + if b.singleCallTimeout > toolLoopSingleCallTimeout+5*time.Second { + b.singleCallTimeout = toolLoopSingleCallTimeout + 5*time.Second + } + if b.maxActDuration > toolLoopMaxActDuration+12*time.Second { + b.maxActDuration = toolLoopMaxActDuration + 12*time.Second + } + if b.maxCallsPerIteration < toolLoopMinCallsPerIteration { + b.maxCallsPerIteration = toolLoopMinCallsPerIteration + } + if b.singleCallTimeout < toolLoopMinSingleCallTimeout { + b.singleCallTimeout = toolLoopMinSingleCallTimeout + } + if b.maxActDuration < toolLoopMinActDuration { + b.maxActDuration = toolLoopMinActDuration + } + return b +} + +func shouldPersistToolResultRecord(record toolExecutionRecord, index int, total int) bool { + if total <= 0 || index < 0 || index >= total { + return false + } + if record.Status == "error" || record.Status == "empty" { + return true + } + return index == 0 || index == total-1 +} + +func (al *AgentLoop) shouldTriggerReflection(state toolLoopState, outcome toolActOutcome) bool { + forceTrigger := false + if outcome.roundToolErrors > 0 { + forceTrigger = true + } + if outcome.hardErrors > 0 { + forceTrigger = true + } + if state.repeatedToolCallRounds > 0 { + forceTrigger = true + } + if al != nil && state.iteration >= al.maxIterations-1 { + forceTrigger = true + } + if outcome.executedCalls > 0 && (strings.TrimSpace(outcome.lastToolResult) == "" || outcome.emptyResults > 0) { + forceTrigger = true + } + if forceTrigger { + return true + } + + // Cooldown: avoid reflection too frequently when no hard risk signals. 
+ if state.lastReflectIteration > 0 && state.iteration-state.lastReflectIteration < reflectionCooldownRounds { + return false + } + // Soft trigger: periodically check progress only when there is meaningful activity. + return outcome.executedCalls > 0 } // sanitizeMessagesForToolCalling removes orphan tool-calling turns so provider-side @@ -2224,6 +4175,18 @@ func sanitizeMessagesForToolCalling(messages []providers.Message) []providers.Me return out } +func toolCallsSignature(calls []providers.ToolCall) string { + if len(calls) == 0 { + return "" + } + parts := make([]string, 0, len(calls)) + for _, tc := range calls { + argsJSON, _ := json.Marshal(tc.Arguments) + parts = append(parts, fmt.Sprintf("%s:%s", strings.TrimSpace(tc.Name), string(argsJSON))) + } + return strings.Join(parts, "|") +} + // truncate returns a truncated version of s with at most maxLen characters. // If the string is truncated, "..." is appended to indicate truncation. // If the string fits within maxLen, it is returned unchanged. @@ -3233,10 +5196,53 @@ func (al *AgentLoop) handleSlashCommand(ctx context.Context, msg bus.InboundMess switch fields[0] { case "/help": - return true, "Slash commands:\n/help\n/status\n/run [--stage-report]\n/config get \n/config set \n/reload\n/pipeline list\n/pipeline status \n/pipeline ready ", nil + return true, "Slash commands:\n/help\n/status\n/status run [run_id|latest]\n/status wait [timeout_seconds]\n/run [--stage-report]\n/config get \n/config set \n/reload\n/pipeline list\n/pipeline status \n/pipeline ready ", nil case "/stop": return true, "Stop command is handled by queue runtime. 
Send /stop from your channel session to interrupt current response.", nil case "/status": + if len(fields) >= 2 { + switch fields[1] { + case "run": + intent := runControlIntent{timeout: defaultRunWaitTimeout} + if len(fields) >= 3 { + target := strings.TrimSpace(fields[2]) + if strings.EqualFold(target, "latest") { + intent.latest = true + } else { + intent.runID = target + } + } else { + intent.latest = true + } + return true, al.executeRunControlIntent(ctx, msg.SessionKey, intent), nil + case "wait": + if len(fields) < 3 { + return true, "Usage: /status wait [timeout_seconds]", nil + } + intent := runControlIntent{ + wait: true, + timeout: defaultRunWaitTimeout, + } + target := strings.TrimSpace(fields[2]) + if strings.EqualFold(target, "latest") { + intent.latest = true + } else { + intent.runID = target + } + if len(fields) >= 4 { + timeoutSec, parseErr := strconv.Atoi(strings.TrimSpace(fields[3])) + if parseErr != nil || timeoutSec <= 0 { + return true, "Usage: /status wait [timeout_seconds]", nil + } + intent.timeout = time.Duration(timeoutSec) * time.Second + if intent.timeout > maxRunWaitTimeout { + intent.timeout = maxRunWaitTimeout + } + } + return true, al.executeRunControlIntent(ctx, msg.SessionKey, intent), nil + } + } + cfg, err := config.LoadConfig(al.getConfigPathForCommands()) if err != nil { return true, "", fmt.Errorf("status failed: %w", err) @@ -3251,14 +5257,15 @@ func (al *AgentLoop) handleSlashCommand(ctx context.Context, msg bus.InboundMess activeBase = p.APIBase } } - return true, fmt.Sprintf("Model: %s\nProxy: %s\nAPI Base: %s\nResponses Compact: %v\nLogging: %v\nConfig: %s", + statusText := fmt.Sprintf("Model: %s\nProxy: %s\nAPI Base: %s\nResponses Compact: %v\nLogging: %v\nConfig: %s", al.model, activeProxy, activeBase, providers.ProviderSupportsResponsesCompact(cfg, activeProxy), cfg.Logging.Enabled, al.getConfigPathForCommands(), - ), nil + ) + return true, statusText, nil case "/reload": running, err := 
al.triggerGatewayReloadFromAgent() if err != nil { diff --git a/pkg/agent/loop_replay_baseline_test.go b/pkg/agent/loop_replay_baseline_test.go new file mode 100644 index 0000000..5eb1076 --- /dev/null +++ b/pkg/agent/loop_replay_baseline_test.go @@ -0,0 +1,242 @@ +package agent + +import ( + "context" + "fmt" + "strings" + "sync" + "testing" + "time" + + "clawgo/pkg/providers" + "clawgo/pkg/session" + "clawgo/pkg/tools" +) + +type replayScenario string + +const ( + replayDirectSuccess replayScenario = "direct_success" + replayOneToolSuccess replayScenario = "one_tool_success" + replayRepeatedToolCall replayScenario = "repeated_tool_call" + replayTransientFailure replayScenario = "transient_failure" + replayPermissionBlock replayScenario = "permission_block" +) + +type replayProvider struct { + mu sync.Mutex + scenario replayScenario + planCalls int + reflectCalls int + finalizeCalls int + polishCalls int + totalCalls int +} + +func (p *replayProvider) Chat(ctx context.Context, messages []providers.Message, defs []providers.ToolDefinition, model string, options map[string]interface{}) (*providers.LLMResponse, error) { + p.mu.Lock() + defer p.mu.Unlock() + p.totalCalls++ + + last := "" + if len(messages) > 0 { + last = strings.TrimSpace(messages[len(messages)-1].Content) + } + + // Phase-2 polish call. + if len(defs) == 0 && len(messages) > 0 && strings.Contains(strings.ToLower(strings.TrimSpace(messages[0].Content)), "rewrite the draft answer for end users") { + p.polishCalls++ + return &providers.LLMResponse{Content: "polished final response"}, nil + } + + // Reflection call. 
+ if len(defs) == 0 && strings.Contains(last, "Classify current execution progress using JSON only.") { + p.reflectCalls++ + switch p.scenario { + case replayTransientFailure: + return &providers.LLMResponse{Content: `{"decision":"continue","reason":"transient failures may recover","confidence":0.74}`}, nil + case replayRepeatedToolCall: + return &providers.LLMResponse{Content: `{"decision":"continue","reason":"need another attempt","confidence":0.72}`}, nil + default: + return &providers.LLMResponse{Content: `{"decision":"continue","reason":"default continue","confidence":0.60}`}, nil + } + } + + // Finalization draft call. + if len(defs) == 0 { + p.finalizeCalls++ + return &providers.LLMResponse{Content: "draft final response"}, nil + } + + // Planning/tool-loop calls. + p.planCalls++ + switch p.scenario { + case replayDirectSuccess: + return &providers.LLMResponse{Content: "direct completed"}, nil + case replayOneToolSuccess: + if p.planCalls == 1 { + return &providers.LLMResponse{ + Content: "call one tool", + ToolCalls: []providers.ToolCall{ + {ID: "tc-1", Name: "ok_tool", Arguments: map[string]interface{}{"x": 1}}, + }, + }, nil + } + return &providers.LLMResponse{Content: "task completed after one tool"}, nil + case replayRepeatedToolCall: + return &providers.LLMResponse{ + Content: "repeating tool", + ToolCalls: []providers.ToolCall{ + {ID: fmt.Sprintf("tc-r-%d", p.planCalls), Name: "ok_tool", Arguments: map[string]interface{}{"same": true}}, + }, + }, nil + case replayTransientFailure: + return &providers.LLMResponse{ + Content: "transient fail tool", + ToolCalls: []providers.ToolCall{ + {ID: fmt.Sprintf("tc-t-%d", p.planCalls), Name: "fail_tool_transient", Arguments: map[string]interface{}{}}, + }, + }, nil + case replayPermissionBlock: + return &providers.LLMResponse{ + Content: "permission fail tool", + ToolCalls: []providers.ToolCall{ + {ID: "tc-p-1", Name: "fail_tool_permission", Arguments: map[string]interface{}{}}, + }, + }, nil + default: + return 
&providers.LLMResponse{Content: "unexpected scenario"}, nil + } +} + +func (p *replayProvider) GetDefaultModel() string { return "test-model" } + +type replayToolImpl struct { + name string + run func(context.Context, map[string]interface{}) (string, error) +} + +func (t replayToolImpl) Name() string { return t.name } +func (t replayToolImpl) Description() string { return "replay tool" } +func (t replayToolImpl) Parameters() map[string]interface{} { + return map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{}, + } +} +func (t replayToolImpl) Execute(ctx context.Context, args map[string]interface{}) (string, error) { + return t.run(ctx, args) +} + +type replayCaseResult struct { + name replayScenario + ok bool + iterations int + llmCalls int + reflectCalls int +} + +func TestAgentLoopReplayBaseline(t *testing.T) { + t.Parallel() + + scenarios := []replayScenario{ + replayDirectSuccess, + replayOneToolSuccess, + replayRepeatedToolCall, + replayTransientFailure, + replayPermissionBlock, + } + + results := make([]replayCaseResult, 0, len(scenarios)) + for _, sc := range scenarios { + sc := sc + t.Run(string(sc), func(t *testing.T) { + reg := tools.NewToolRegistry() + reg.Register(replayToolImpl{ + name: "ok_tool", + run: func(ctx context.Context, args map[string]interface{}) (string, error) { + return "ok", nil + }, + }) + reg.Register(replayToolImpl{ + name: "fail_tool_transient", + run: func(ctx context.Context, args map[string]interface{}) (string, error) { + return "", fmt.Errorf("temporary unavailable 503") + }, + }) + reg.Register(replayToolImpl{ + name: "fail_tool_permission", + run: func(ctx context.Context, args map[string]interface{}) (string, error) { + return "", fmt.Errorf("permission denied") + }, + }) + + provider := &replayProvider{scenario: sc} + al := &AgentLoop{ + provider: provider, + providersByProxy: map[string]providers.LLMProvider{"proxy": provider}, + modelsByProxy: map[string][]string{"proxy": {"test-model"}}, 
+ proxy: "proxy", + model: "test-model", + maxIterations: 6, + llmCallTimeout: 3 * time.Second, + tools: reg, + sessions: session.NewSessionManager(""), + workspace: t.TempDir(), + } + + msgs := []providers.Message{ + {Role: "system", Content: "you are a test agent"}, + {Role: "user", Content: "complete task"}, + } + + out, iterations, err := al.runLLMToolLoop(context.Background(), msgs, "replay:"+string(sc), false, nil) + if err != nil { + t.Fatalf("runLLMToolLoop error: %v", err) + } + if strings.TrimSpace(out) == "" { + t.Fatalf("empty output") + } + results = append(results, replayCaseResult{ + name: sc, + ok: true, + iterations: iterations, + llmCalls: provider.totalCalls, + reflectCalls: provider.reflectCalls, + }) + }) + } + + total := len(results) + if total != len(scenarios) { + t.Fatalf("unexpected results count: %d", total) + } + success := 0 + iterSum := 0 + llmSum := 0 + reflectSum := 0 + for _, r := range results { + if r.ok { + success++ + } + iterSum += r.iterations + llmSum += r.llmCalls + reflectSum += r.reflectCalls + } + successRate := float64(success) / float64(total) + avgIter := float64(iterSum) / float64(total) + avgLLM := float64(llmSum) / float64(total) + avgReflect := float64(reflectSum) / float64(total) + + t.Logf("Replay baseline: success_rate=%.2f avg_iterations=%.2f avg_llm_calls=%.2f avg_reflect_calls=%.2f", successRate, avgIter, avgLLM, avgReflect) + + if successRate < 1.0 { + t.Fatalf("expected all scenarios to succeed, got success_rate=%.2f", successRate) + } + if avgIter > 3.6 { + t.Fatalf("avg_iterations too high: %.2f", avgIter) + } + if avgLLM > 6.0 { + t.Fatalf("avg_llm_calls too high: %.2f", avgLLM) + } +} diff --git a/pkg/agent/loop_run_control_test.go b/pkg/agent/loop_run_control_test.go new file mode 100644 index 0000000..2cda8ac --- /dev/null +++ b/pkg/agent/loop_run_control_test.go @@ -0,0 +1,192 @@ +package agent + +import ( + "context" + "testing" + "time" + + "clawgo/pkg/bus" +) + +func TestDetectRunControlIntent(t 
*testing.T) { + t.Parallel() + + intent, ok := detectRunControlIntent("请等待 run-123-7 120 秒后告诉我状态") + if !ok { + t.Fatalf("expected run control intent") + } + if intent.runID != "run-123-7" { + t.Fatalf("unexpected run id: %s", intent.runID) + } + if !intent.wait { + t.Fatalf("expected wait=true") + } + if intent.timeout != 120*time.Second { + t.Fatalf("unexpected timeout: %s", intent.timeout) + } +} + +func TestDetectRunControlIntentLatest(t *testing.T) { + t.Parallel() + + intent, ok := detectRunControlIntent("latest run status") + if !ok { + t.Fatalf("expected latest run status intent") + } + if !intent.latest { + t.Fatalf("expected latest=true") + } + if intent.runID != "" { + t.Fatalf("expected empty run id") + } +} + +func TestParseRunWaitTimeout_MinClamp(t *testing.T) { + t.Parallel() + + got := parseRunWaitTimeout("wait run-1-1 1 s") + if got != minRunWaitTimeout { + t.Fatalf("expected min timeout %s, got %s", minRunWaitTimeout, got) + } +} + +func TestParseRunWaitTimeout_MinuteUnit(t *testing.T) { + t.Parallel() + + got := parseRunWaitTimeout("等待 run-1-1 2 分钟") + if got != 2*time.Minute { + t.Fatalf("expected 2m, got %s", got) + } +} + +func TestDetectRunControlIntentIgnoresNonControlText(t *testing.T) { + t.Parallel() + + if _, ok := detectRunControlIntent("帮我写一个README"); ok { + t.Fatalf("did not expect run control intent") + } +} + +func TestDetectRunControlIntentWithCustomLexicon(t *testing.T) { + t.Parallel() + + lex := runControlLexicon{ + latestKeywords: []string{"newest"}, + waitKeywords: []string{"block"}, + statusKeywords: []string{"health"}, + runMentionKeywords: []string{"job"}, + minuteUnits: map[string]struct{}{"mins": {}}, + } + + intent, ok := detectRunControlIntentWithLexicon("block run-55-1 for 2 mins and show health", lex) + if !ok { + t.Fatalf("expected intent with custom lexicon") + } + if !intent.wait { + t.Fatalf("expected wait=true") + } + if intent.timeout != 2*time.Minute { + t.Fatalf("unexpected timeout: %s", intent.timeout) + } +} 
+ +func TestLatestRunStateBySession(t *testing.T) { + t.Parallel() + + now := time.Now() + al := &AgentLoop{ + runStates: map[string]*runState{ + "run-1-1": { + runID: "run-1-1", + sessionKey: "s1", + startedAt: now.Add(-2 * time.Minute), + }, + "run-1-2": { + runID: "run-1-2", + sessionKey: "s1", + startedAt: now.Add(-time.Minute), + }, + "run-2-1": { + runID: "run-2-1", + sessionKey: "s2", + startedAt: now, + }, + }, + } + + rs, ok := al.latestRunState("s1") + if !ok { + t.Fatalf("expected state for s1") + } + if rs.runID != "run-1-2" { + t.Fatalf("unexpected run id: %s", rs.runID) + } +} + +func TestHandleSlashCommand_StatusRunLatest(t *testing.T) { + t.Parallel() + + al := &AgentLoop{ + runStates: map[string]*runState{ + "run-100-1": { + runID: "run-100-1", + sessionKey: "s1", + status: runStatusOK, + acceptedAt: time.Now().Add(-time.Minute), + startedAt: time.Now().Add(-time.Minute), + endedAt: time.Now().Add(-30 * time.Second), + done: closedChan(), + }, + }, + } + handled, out, err := al.handleSlashCommand(context.Background(), bus.InboundMessage{ + Content: "/status run latest", + SessionKey: "s1", + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !handled { + t.Fatalf("expected command handled") + } + if out == "" || !containsAnySubstring(out, "run-100-1", "Run ID: run-100-1") { + t.Fatalf("unexpected output: %s", out) + } +} + +func TestHandleSlashCommand_StatusWaitDoneRun(t *testing.T) { + t.Parallel() + + al := &AgentLoop{ + runStates: map[string]*runState{ + "run-200-2": { + runID: "run-200-2", + sessionKey: "s1", + status: runStatusOK, + acceptedAt: time.Now().Add(-time.Minute), + startedAt: time.Now().Add(-time.Minute), + endedAt: time.Now().Add(-20 * time.Second), + done: closedChan(), + }, + }, + } + handled, out, err := al.handleSlashCommand(context.Background(), bus.InboundMessage{ + Content: "/status wait run-200-2 3", + SessionKey: "s1", + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !handled { + 
t.Fatalf("expected command handled") + } + if out == "" || !containsAnySubstring(out, "run-200-2", "Run ID: run-200-2") { + t.Fatalf("unexpected output: %s", out) + } +} + +func closedChan() chan struct{} { + ch := make(chan struct{}) + close(ch) + return ch +} diff --git a/pkg/agent/loop_toolloop_test.go b/pkg/agent/loop_toolloop_test.go new file mode 100644 index 0000000..6292750 --- /dev/null +++ b/pkg/agent/loop_toolloop_test.go @@ -0,0 +1,617 @@ +package agent + +import ( + "context" + "fmt" + "strings" + "sync/atomic" + "testing" + "time" + + "clawgo/pkg/config" + "clawgo/pkg/providers" + "clawgo/pkg/session" + "clawgo/pkg/tools" +) + +func TestToolCallsSignatureStableForSameInput(t *testing.T) { + t.Parallel() + + calls := []providers.ToolCall{ + { + Name: "shell", + Arguments: map[string]interface{}{"cmd": "ls -la", "cwd": "/tmp"}, + }, + { + Name: "read_file", + Arguments: map[string]interface{}{"path": "README.md"}, + }, + } + + s1 := toolCallsSignature(calls) + s2 := toolCallsSignature(calls) + if s1 == "" { + t.Fatalf("expected non-empty signature") + } + if s1 != s2 { + t.Fatalf("expected stable signature, got %q vs %q", s1, s2) + } +} + +func TestToolCallsSignatureDiffersByArguments(t *testing.T) { + t.Parallel() + + callsA := []providers.ToolCall{ + {Name: "shell", Arguments: map[string]interface{}{"cmd": "ls -la"}}, + } + callsB := []providers.ToolCall{ + {Name: "shell", Arguments: map[string]interface{}{"cmd": "pwd"}}, + } + + if toolCallsSignature(callsA) == toolCallsSignature(callsB) { + t.Fatalf("expected different signatures for different arguments") + } +} + +func TestNormalizeReflectDecision(t *testing.T) { + t.Parallel() + + if got := normalizeReflectDecision("DONE"); got != "done" { + t.Fatalf("expected done, got %s", got) + } + if got := normalizeReflectDecision("blocked"); got != "blocked" { + t.Fatalf("expected blocked, got %s", got) + } + if got := normalizeReflectDecision("unknown"); got != "continue" { + t.Fatalf("expected continue, 
got %s", got) + } +} + +func TestShouldTriggerReflectionReplayScenarios(t *testing.T) { + t.Parallel() + + al := &AgentLoop{maxIterations: 5} + tests := []struct { + name string + state toolLoopState + outcome toolActOutcome + want bool + }{ + { + name: "tool failure", + state: toolLoopState{iteration: 2}, + outcome: toolActOutcome{executedCalls: 2, roundToolErrors: 1, lastToolResult: "Error: denied"}, + want: true, + }, + { + name: "repetition hint", + state: toolLoopState{iteration: 2, repeatedToolCallRounds: 1}, + outcome: toolActOutcome{executedCalls: 1, lastToolResult: "ok"}, + want: true, + }, + { + name: "near iteration limit", + state: toolLoopState{iteration: 4}, + outcome: toolActOutcome{executedCalls: 1, lastToolResult: "ok"}, + want: true, + }, + { + name: "empty tool result", + state: toolLoopState{iteration: 1}, + outcome: toolActOutcome{executedCalls: 1, lastToolResult: ""}, + want: true, + }, + { + name: "healthy progress", + state: toolLoopState{iteration: 1}, + outcome: toolActOutcome{executedCalls: 1, lastToolResult: "done step 1"}, + want: false, + }, + } + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + got := al.shouldTriggerReflection(tt.state, tt.outcome) + if got != tt.want { + t.Fatalf("shouldTriggerReflection=%v want=%v", got, tt.want) + } + }) + } +} + +func TestShouldTriggerReflectionCooldown(t *testing.T) { + t.Parallel() + + al := &AgentLoop{maxIterations: 10} + state := toolLoopState{ + iteration: 3, + lastReflectIteration: 2, + } + // No hard trigger, within cooldown window -> false. + if al.shouldTriggerReflection(state, toolActOutcome{executedCalls: 1, lastToolResult: "ok"}) { + t.Fatalf("expected reflection suppressed by cooldown") + } + + // Hard trigger bypasses cooldown. 
+ if !al.shouldTriggerReflection(state, toolActOutcome{executedCalls: 1, roundToolErrors: 1, lastToolResult: "Error: x"}) { + t.Fatalf("expected hard trigger to bypass cooldown") + } +} + +type replayTool struct { + name string + parallelSafe *bool + resourceKeys func(args map[string]interface{}) []string + run func(context.Context, map[string]interface{}) (string, error) +} + +func (t replayTool) Name() string { return t.name } +func (t replayTool) Description() string { return "replay tool" } +func (t replayTool) Parameters() map[string]interface{} { + return map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{}, + } +} +func (t replayTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) { + if t.run != nil { + return t.run(ctx, args) + } + return fmt.Sprintf("ok:%s", t.name), nil +} + +func (t replayTool) ParallelSafe() bool { + if t.parallelSafe == nil { + return false + } + return *t.parallelSafe +} + +func (t replayTool) ResourceKeys(args map[string]interface{}) []string { + if t.resourceKeys == nil { + return nil + } + return t.resourceKeys(args) +} + +func TestActToolCalls_BudgetTruncationReplay(t *testing.T) { + t.Parallel() + + reg := tools.NewToolRegistry() + calls := make([]providers.ToolCall, 0, toolLoopMaxCallsPerIteration+2) + for i := 0; i < toolLoopMaxCallsPerIteration+2; i++ { + name := fmt.Sprintf("tool_%d", i) + reg.Register(replayTool{name: name}) + calls = append(calls, providers.ToolCall{ + ID: fmt.Sprintf("tc-%d", i), + Name: name, + Arguments: map[string]interface{}{}, + }) + } + + al := &AgentLoop{ + tools: reg, + sessions: session.NewSessionManager(""), + } + msgs := []providers.Message{} + out := al.actToolCalls(context.Background(), "", calls, &msgs, "s1", 1, toolLoopBudget{}, false, nil) + + if !out.truncated { + t.Fatalf("expected truncation due to budget") + } + if out.executedCalls != toolLoopMaxCallsPerIteration { + t.Fatalf("executed=%d want=%d", out.executedCalls, 
toolLoopMaxCallsPerIteration) + } + if out.droppedCalls != 2 { + t.Fatalf("dropped=%d want=2", out.droppedCalls) + } +} + +func TestComputeToolLoopBudget(t *testing.T) { + t.Parallel() + + al := &AgentLoop{maxIterations: 6} + + early := al.computeToolLoopBudget(toolLoopState{iteration: 1}) + if early.maxCallsPerIteration <= toolLoopMaxCallsPerIteration { + t.Fatalf("expected wider early budget, got %d", early.maxCallsPerIteration) + } + + degraded := al.computeToolLoopBudget(toolLoopState{iteration: 2, consecutiveAllToolErrorRounds: 1}) + if degraded.maxCallsPerIteration >= toolLoopMaxCallsPerIteration { + t.Fatalf("expected tighter degraded budget, got %d", degraded.maxCallsPerIteration) + } + + nearLimit := al.computeToolLoopBudget(toolLoopState{iteration: 5}) + if nearLimit.maxCallsPerIteration != toolLoopMinCallsPerIteration { + t.Fatalf("expected minimal near-limit calls, got %d", nearLimit.maxCallsPerIteration) + } + if nearLimit.singleCallTimeout != toolLoopMinSingleCallTimeout { + t.Fatalf("expected minimal near-limit timeout, got %s", nearLimit.singleCallTimeout) + } + + lowConfContinue := al.computeToolLoopBudget(toolLoopState{ + iteration: 2, + lastReflectDecision: "continue", + lastReflectConfidence: 0.42, + lastReflectIteration: 1, + }) + if lowConfContinue.maxCallsPerIteration >= toolLoopMaxCallsPerIteration { + t.Fatalf("expected low-confidence continue to tighten calls, got %d", lowConfContinue.maxCallsPerIteration) + } + + highConfContinue := al.computeToolLoopBudget(toolLoopState{ + iteration: 2, + lastReflectDecision: "continue", + lastReflectConfidence: 0.91, + lastReflectIteration: 1, + }) + if highConfContinue.maxCallsPerIteration <= toolLoopMaxCallsPerIteration { + t.Fatalf("expected high-confidence continue to widen calls, got %d", highConfContinue.maxCallsPerIteration) + } + + blocked := al.computeToolLoopBudget(toolLoopState{ + iteration: 2, + lastReflectDecision: "blocked", + lastReflectConfidence: 0.8, + lastReflectIteration: 1, + }) + 
if blocked.maxCallsPerIteration != toolLoopMinCallsPerIteration { + t.Fatalf("expected blocked reflection to force min calls, got %d", blocked.maxCallsPerIteration) + } +} + +func TestParallelSafeToolDeclarationOverridesWhitelist(t *testing.T) { + t.Parallel() + + yes := true + no := false + reg := tools.NewToolRegistry() + reg.Register(replayTool{name: "read_file", parallelSafe: &no}) + reg.Register(replayTool{name: "custom_safe", parallelSafe: &yes}) + + al := &AgentLoop{ + tools: reg, + parallelSafeTools: map[string]struct{}{ + "read_file": {}, + }, + } + + if al.isParallelSafeTool("read_file") { + t.Fatalf("tool declaration should override whitelist to false") + } + if !al.isParallelSafeTool("custom_safe") { + t.Fatalf("tool declaration true should be respected") + } +} + +func TestClassifyToolExecutionError(t *testing.T) { + t.Parallel() + + typ, retryable, blocked := classifyToolExecutionError(fmt.Errorf("permission denied to write file"), false) + if typ != "permission" || retryable || !blocked { + t.Fatalf("unexpected permission classification: %s %v %v", typ, retryable, blocked) + } + + typ, retryable, blocked = classifyToolExecutionError(fmt.Errorf("temporary unavailable 503"), false) + if typ != "transient" || !retryable || blocked { + t.Fatalf("unexpected transient classification: %s %v %v", typ, retryable, blocked) + } +} + +func TestSummarizeToolActOutcome(t *testing.T) { + t.Parallel() + + out := summarizeToolActOutcome(toolActOutcome{ + executedCalls: 1, + records: []toolExecutionRecord{ + {Tool: "shell", Status: "error", ErrorType: "permission", Retryable: false}, + }, + hardErrors: 1, + blockedLikely: true, + }) + if out == "" || !strings.Contains(out, "\"blocked_likely\":true") { + t.Fatalf("unexpected summary: %s", out) + } + if !strings.Contains(out, "\"error_type\":\"permission\"") { + t.Fatalf("missing record fields in summary: %s", out) + } + if !strings.Contains(out, "\"records_truncated\":0") { + t.Fatalf("expected records_truncated field, 
got: %s", out) + } +} + +func TestShouldPersistToolResultRecord(t *testing.T) { + t.Parallel() + + if !shouldPersistToolResultRecord(toolExecutionRecord{Status: "ok"}, 0, 3) { + t.Fatalf("first tool result should persist") + } + if !shouldPersistToolResultRecord(toolExecutionRecord{Status: "ok"}, 2, 3) { + t.Fatalf("last tool result should persist") + } + if shouldPersistToolResultRecord(toolExecutionRecord{Status: "ok"}, 1, 3) { + t.Fatalf("middle successful tool result should be skipped") + } + if !shouldPersistToolResultRecord(toolExecutionRecord{Status: "error"}, 1, 3) { + t.Fatalf("error tool result should persist") + } +} + +func TestCompactToolExecutionRecords(t *testing.T) { + t.Parallel() + + records := []toolExecutionRecord{ + {Tool: "a", Status: "ok"}, + {Tool: "b", Status: "error", ErrorType: "permission"}, + {Tool: "c", Status: "ok"}, + {Tool: "d", Status: "error", ErrorType: "transient"}, + {Tool: "e", Status: "ok"}, + {Tool: "f", Status: "ok"}, + } + out, truncated := compactToolExecutionRecords(records, 4) + if len(out) != 4 { + t.Fatalf("expected compact len 4, got %d", len(out)) + } + if truncated != 2 { + t.Fatalf("expected truncated 2, got %d", truncated) + } + foundErr := 0 + for _, r := range out { + if r.Status == "error" { + foundErr++ + } + } + if foundErr < 2 { + t.Fatalf("expected to keep error records, got %d", foundErr) + } +} + +func TestShouldRunToolCallsInParallel(t *testing.T) { + t.Parallel() + + al := &AgentLoop{ + parallelSafeTools: map[string]struct{}{ + "read_file": {}, + "memory_search": {}, + }, + } + ok := al.shouldRunToolCallsInParallel([]providers.ToolCall{ + {Name: "read_file"}, {Name: "memory_search"}, + }) + if !ok { + t.Fatalf("expected parallel-safe tools to run in parallel") + } + + notOK := al.shouldRunToolCallsInParallel([]providers.ToolCall{ + {Name: "read_file"}, {Name: "shell"}, + }) + if notOK { + t.Fatalf("expected mixed tool set to stay serial") + } +} + +func TestActToolCalls_ParallelExecutionForSafeTools(t 
*testing.T) { + t.Parallel() + + var active int32 + var maxActive int32 + probe := func() { + cur := atomic.AddInt32(&active, 1) + for { + old := atomic.LoadInt32(&maxActive) + if cur <= old || atomic.CompareAndSwapInt32(&maxActive, old, cur) { + break + } + } + time.Sleep(40 * time.Millisecond) + atomic.AddInt32(&active, -1) + } + yes := true // replayTool.ParallelSafe reports false while parallelSafe is nil, so opt in explicitly + reg := tools.NewToolRegistry() + reg.Register(replayTool{name: "read_file", parallelSafe: &yes, run: func(ctx context.Context, args map[string]interface{}) (string, error) { + probe() + return "ok", nil + }}) + reg.Register(replayTool{name: "memory_search", parallelSafe: &yes, run: func(ctx context.Context, args map[string]interface{}) (string, error) { + probe() + return "ok", nil + }}) + + al := &AgentLoop{ + tools: reg, + sessions: session.NewSessionManager(""), + parallelSafeTools: map[string]struct{}{"read_file": {}, "memory_search": {}}, + maxParallelCalls: 2, + } + msgs := []providers.Message{} + calls := []providers.ToolCall{ + {ID: "1", Name: "read_file", Arguments: map[string]interface{}{}}, + {ID: "2", Name: "memory_search", Arguments: map[string]interface{}{}}, + } + + al.actToolCalls(context.Background(), "", calls, &msgs, "s1", 1, toolLoopBudget{ + maxCallsPerIteration: 2, + singleCallTimeout: 2 * time.Second, + maxActDuration: 2 * time.Second, + }, false, nil) + + if atomic.LoadInt32(&maxActive) < 2 { + t.Fatalf("expected concurrent execution, maxActive=%d", maxActive) + } +} + +func TestActToolCalls_ResourceConflictForcesSerial(t *testing.T) { + t.Parallel() + + var active int32 + var maxActive int32 + probe := func() { + cur := atomic.AddInt32(&active, 1) + for { + old := atomic.LoadInt32(&maxActive) + if cur <= old || atomic.CompareAndSwapInt32(&maxActive, old, cur) { + break + } + } + time.Sleep(35 * time.Millisecond) + atomic.AddInt32(&active, -1) + } + + yes := true + reg := tools.NewToolRegistry() + reg.Register(replayTool{ + name: "read_file", + parallelSafe: &yes, + resourceKeys: func(args map[string]interface{}) []string { return 
[]string{"fs:/tmp/a"} }, + run: func(ctx context.Context, args map[string]interface{}) (string, error) { + probe() + return "ok", nil + }, + }) + reg.Register(replayTool{ + name: "memory_search", + parallelSafe: &yes, + resourceKeys: func(args map[string]interface{}) []string { return []string{"fs:/tmp/a"} }, + run: func(ctx context.Context, args map[string]interface{}) (string, error) { + probe() + return "ok", nil + }, + }) + + al := &AgentLoop{ + tools: reg, + sessions: session.NewSessionManager(""), + parallelSafeTools: map[string]struct{}{"read_file": {}, "memory_search": {}}, + maxParallelCalls: 2, + } + + msgs := []providers.Message{} + calls := []providers.ToolCall{ + {ID: "1", Name: "read_file", Arguments: map[string]interface{}{}}, + {ID: "2", Name: "memory_search", Arguments: map[string]interface{}{}}, + } + al.actToolCalls(context.Background(), "", calls, &msgs, "s1", 1, toolLoopBudget{ + maxCallsPerIteration: 2, + singleCallTimeout: 2 * time.Second, + maxActDuration: 2 * time.Second, + }, false, nil) + + if atomic.LoadInt32(&maxActive) > 1 { + t.Fatalf("expected serial execution on same resource key, maxActive=%d", maxActive) + } +} + +func TestLoadToolParallelPolicyFromConfig(t *testing.T) { + t.Parallel() + + allowed, maxCalls := loadToolParallelPolicyFromConfig(config.RuntimeControlConfig{ + ToolParallelSafeNames: []string{"Read_File", "memory_search"}, + ToolMaxParallelCalls: 3, + }) + if maxCalls != 3 { + t.Fatalf("unexpected max calls: %d", maxCalls) + } + if _, ok := allowed["read_file"]; !ok { + t.Fatalf("expected normalized read_file in allowed set") + } +} + +func TestShouldRunFinalizePolish(t *testing.T) { + t.Parallel() + + short := "done" + if shouldRunFinalizePolish(short) { + t.Fatalf("short draft should skip polish") + } + + longButFlat := strings.Repeat("a", finalizeDraftMinCharsForPolish+10) + if shouldRunFinalizePolish(longButFlat) { + t.Fatalf("flat draft should skip polish") + } + + longStructured := "1. 
Step one: check environment variables and baseline configs.\n2. Step two: apply fix and rerun validations.\nNext: verify rollout and provide follow-up actions." + if !shouldRunFinalizePolish(longStructured) { + t.Fatalf("structured draft should trigger polish") + } +} + +func TestLocalFinalizeDraftQualityScore(t *testing.T) { + t.Parallel() + + high := localFinalizeDraftQualityScore("1. Step one: inspect environment.\n2. Step two: apply fix.\nNext steps: validate rollout and summarize conclusions.") + low := localFinalizeDraftQualityScore("todo\ntodo\ntodo") + if high <= low { + t.Fatalf("expected high-quality score > low-quality score, got %.2f <= %.2f", high, low) + } + if high < 0.30 { + t.Fatalf("unexpectedly low high-quality score: %.2f", high) + } +} + +func TestClamp01(t *testing.T) { + t.Parallel() + + if got := clamp01(-0.1); got != 0 { + t.Fatalf("expected 0, got %v", got) + } + if got := clamp01(1.2); got != 1 { + t.Fatalf("expected 1, got %v", got) + } +} + +func TestInferLocalReflectionSignal(t *testing.T) { + t.Parallel() + + blocked := inferLocalReflectionSignal([]providers.Message{ + {Role: "tool", Content: "Error: permission denied"}, + {Role: "tool", Content: "Error: permission denied"}, + }) + if blocked.decision != "blocked" || blocked.uncertain { + t.Fatalf("expected blocked deterministic signal, got %+v", blocked) + } + + done := inferLocalReflectionSignal([]providers.Message{ + {Role: "tool", Content: "success: completed ok"}, + }) + if done.decision != "done" || done.uncertain { + t.Fatalf("expected done deterministic signal, got %+v", done) + } + + unknown := inferLocalReflectionSignal([]providers.Message{ + {Role: "tool", Content: "partial result"}, + }) + if unknown.decision != "continue" || !unknown.uncertain { + t.Fatalf("expected uncertain continue signal, got %+v", unknown) + } +} + +func TestShouldForceSelfRepairHeuristic(t *testing.T) { + t.Parallel() + + needs, prompt := shouldForceSelfRepairHeuristic("Please provide steps to fix 
this", "It should work.") + if !needs || strings.TrimSpace(prompt) == "" { + t.Fatalf("expected self-repair for missing structured steps") + } + + needs, _ = shouldForceSelfRepairHeuristic("summarize logs", "Here is summary.") + if needs { + t.Fatalf("did not expect repair for normal concise response") + } +} + +func TestSelfRepairMemoryPromptDedup(t *testing.T) { + t.Parallel() + + mem := selfRepairMemory{ + promptsUsed: map[string]struct{}{ + normalizeRepairPrompt("Provide structured step-by-step answer."): {}, + }, + } + if !promptSeen(mem, "provide structured step-by-step answer.") { + t.Fatalf("expected prompt to be detected as already used") + } + if promptSeen(mem, "different prompt") { + t.Fatalf("did not expect unrelated prompt to be marked used") + } +} diff --git a/pkg/config/config.go b/pkg/config/config.go index a0e95cb..d96c1b9 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -1,7 +1,10 @@ package config import ( + "bytes" "encoding/json" + "fmt" + "io" "os" "path/filepath" "sync" @@ -35,6 +38,31 @@ type AgentDefaults struct { Temperature float64 `json:"temperature" env:"CLAWGO_AGENTS_DEFAULTS_TEMPERATURE"` MaxToolIterations int `json:"max_tool_iterations" env:"CLAWGO_AGENTS_DEFAULTS_MAX_TOOL_ITERATIONS"` ContextCompaction ContextCompactionConfig `json:"context_compaction"` + RuntimeControl RuntimeControlConfig `json:"runtime_control"` +} + +type RuntimeControlConfig struct { + IntentHighConfidence float64 `json:"intent_high_confidence" env:"CLAWGO_INTENT_HIGH_CONFIDENCE"` + IntentConfirmMinConfidence float64 `json:"intent_confirm_min_confidence" env:"CLAWGO_INTENT_CONFIRM_MIN_CONFIDENCE"` + IntentMaxInputChars int `json:"intent_max_input_chars" env:"CLAWGO_INTENT_MAX_INPUT_CHARS"` + ConfirmTTLSeconds int `json:"confirm_ttl_seconds" env:"CLAWGO_CONFIRM_TTL_SECONDS"` + ConfirmMaxClarificationTurns int `json:"confirm_max_clarification_turns" env:"CLAWGO_CONFIRM_MAX_CLARIFY_TURNS"` + AutonomyTickIntervalSec int 
`json:"autonomy_tick_interval_sec" env:"CLAWGO_AUTONOMY_TICK_INTERVAL_SEC"` + AutonomyMinRunIntervalSec int `json:"autonomy_min_run_interval_sec" env:"CLAWGO_AUTONOMY_MIN_RUN_INTERVAL_SEC"` + AutonomyIdleThresholdSec int `json:"autonomy_idle_threshold_sec" env:"CLAWGO_AUTONOMY_IDLE_THRESHOLD_SEC"` + AutonomyMaxRoundsWithoutUser int `json:"autonomy_max_rounds_without_user" env:"CLAWGO_AUTONOMY_MAX_ROUNDS_WITHOUT_USER"` + AutonomyMaxPendingDurationSec int `json:"autonomy_max_pending_duration_sec" env:"CLAWGO_AUTONOMY_MAX_PENDING_DURATION_SEC"` + AutonomyMaxConsecutiveStalls int `json:"autonomy_max_consecutive_stalls" env:"CLAWGO_AUTONOMY_MAX_STALLS"` + AutoLearnMaxRoundsWithoutUser int `json:"autolearn_max_rounds_without_user" env:"CLAWGO_AUTOLEARN_MAX_ROUNDS_WITHOUT_USER"` + RunStateTTLSeconds int `json:"run_state_ttl_seconds" env:"CLAWGO_RUN_STATE_TTL_SECONDS"` + RunStateMax int `json:"run_state_max" env:"CLAWGO_RUN_STATE_MAX"` + RunControlLatestKeywords []string `json:"run_control_latest_keywords"` + RunControlWaitKeywords []string `json:"run_control_wait_keywords"` + RunControlStatusKeywords []string `json:"run_control_status_keywords"` + RunControlRunMentionKeywords []string `json:"run_control_run_mention_keywords"` + RunControlMinuteUnits []string `json:"run_control_minute_units"` + ToolParallelSafeNames []string `json:"tool_parallel_safe_names"` + ToolMaxParallelCalls int `json:"tool_max_parallel_calls"` } type ContextCompactionConfig struct { @@ -246,6 +274,29 @@ func DefaultConfig() *Config { MaxSummaryChars: 6000, MaxTranscriptChars: 20000, }, + RuntimeControl: RuntimeControlConfig{ + IntentHighConfidence: 0.75, + IntentConfirmMinConfidence: 0.45, + IntentMaxInputChars: 1200, + ConfirmTTLSeconds: 300, + ConfirmMaxClarificationTurns: 2, + AutonomyTickIntervalSec: 20, + AutonomyMinRunIntervalSec: 20, + AutonomyIdleThresholdSec: 20, + AutonomyMaxRoundsWithoutUser: 120, + AutonomyMaxPendingDurationSec: 180, + AutonomyMaxConsecutiveStalls: 3, + 
AutoLearnMaxRoundsWithoutUser: 200, + RunStateTTLSeconds: 1800, + RunStateMax: 500, + RunControlLatestKeywords: []string{"latest", "last run", "recent run", "最新", "最近", "上一次", "上个"}, + RunControlWaitKeywords: []string{"wait", "等待", "等到", "阻塞"}, + RunControlStatusKeywords: []string{"status", "状态", "进度", "running", "运行"}, + RunControlRunMentionKeywords: []string{"run", "任务"}, + RunControlMinuteUnits: []string{"分钟", "min", "mins", "minute", "minutes", "m"}, + ToolParallelSafeNames: []string{"read_file", "list_files", "find_files", "grep_files", "memory_search", "web_search", "repo_map", "system_info"}, + ToolMaxParallelCalls: 2, + }, }, }, Channels: ChannelsConfig{ @@ -382,7 +433,7 @@ func LoadConfig(path string) (*Config, error) { return nil, err } - if err := json.Unmarshal(data, cfg); err != nil { + if err := unmarshalConfigStrict(data, cfg); err != nil { return nil, err } @@ -393,6 +444,22 @@ func LoadConfig(path string) (*Config, error) { return cfg, nil } +func unmarshalConfigStrict(data []byte, cfg *Config) error { + dec := json.NewDecoder(bytes.NewReader(data)) + dec.DisallowUnknownFields() + if err := dec.Decode(cfg); err != nil { + return err + } + var extra json.RawMessage + if err := dec.Decode(&extra); err != io.EOF { + if err == nil { + return fmt.Errorf("invalid config: trailing JSON content") + } + return err + } + return nil +} + func SaveConfig(path string, cfg *Config) error { cfg.mu.RLock() defer cfg.mu.RUnlock() diff --git a/pkg/config/config_test.go b/pkg/config/config_test.go new file mode 100644 index 0000000..15d0789 --- /dev/null +++ b/pkg/config/config_test.go @@ -0,0 +1,95 @@ +package config + +import ( + "os" + "path/filepath" + "strings" + "testing" +) + +func TestLoadConfigRejectsUnknownField(t *testing.T) { + t.Parallel() + + dir := t.TempDir() + cfgPath := filepath.Join(dir, "config.json") + content := `{ + "agents": { + "defaults": { + "runtime_control": { + "intent_high_confidence": 0.8, + "unknown_field": 1 + } + } + } +}` + if err 
:= os.WriteFile(cfgPath, []byte(content), 0o644); err != nil { + t.Fatalf("write config: %v", err) + } + + _, err := LoadConfig(cfgPath) + if err == nil { + t.Fatalf("expected unknown field error") + } + if !strings.Contains(strings.ToLower(err.Error()), "unknown field") { + t.Fatalf("expected unknown field error, got: %v", err) + } +} + +func TestLoadConfigRejectsTrailingJSONContent(t *testing.T) { + t.Parallel() + + dir := t.TempDir() + cfgPath := filepath.Join(dir, "config.json") + content := `{"agents":{"defaults":{"runtime_control":{"intent_high_confidence":0.8}}}}{"extra":true}` + if err := os.WriteFile(cfgPath, []byte(content), 0o644); err != nil { + t.Fatalf("write config: %v", err) + } + + _, err := LoadConfig(cfgPath) + if err == nil { + t.Fatalf("expected trailing json content error") + } + if !strings.Contains(err.Error(), "trailing JSON content") { + t.Fatalf("expected trailing JSON content error, got: %v", err) + } +} + +func TestLoadConfigAllowsKnownRuntimeControlFields(t *testing.T) { + t.Parallel() + + dir := t.TempDir() + cfgPath := filepath.Join(dir, "config.json") + content := `{ + "agents": { + "defaults": { + "runtime_control": { + "intent_high_confidence": 0.88, + "run_state_max": 321, + "run_control_wait_keywords": ["wait", "block"], + "tool_parallel_safe_names": ["read_file", "memory_search"], + "tool_max_parallel_calls": 3 + } + } + } +}` + if err := os.WriteFile(cfgPath, []byte(content), 0o644); err != nil { + t.Fatalf("write config: %v", err) + } + + cfg, err := LoadConfig(cfgPath) + if err != nil { + t.Fatalf("load config: %v", err) + } + if got := cfg.Agents.Defaults.RuntimeControl.IntentHighConfidence; got != 0.88 { + t.Fatalf("intent_high_confidence mismatch: got %.2f", got) + } + if got := cfg.Agents.Defaults.RuntimeControl.RunStateMax; got != 321 { + t.Fatalf("run_state_max mismatch: got %d", got) + } + if got := len(cfg.Agents.Defaults.RuntimeControl.RunControlWaitKeywords); got != 2 { + t.Fatalf("run_control_wait_keywords 
mismatch: got %d", got) + } + if got := cfg.Agents.Defaults.RuntimeControl.ToolMaxParallelCalls; got != 3 { + t.Fatalf("tool_max_parallel_calls mismatch: got %d", got) + } +} diff --git a/pkg/config/validate.go b/pkg/config/validate.go index 0e9b8ee..1071566 100644 --- a/pkg/config/validate.go +++ b/pkg/config/validate.go @@ -17,6 +17,58 @@ func Validate(cfg *Config) []error { if cfg.Agents.Defaults.MaxToolIterations <= 0 { errs = append(errs, fmt.Errorf("agents.defaults.max_tool_iterations must be > 0")) } + rc := cfg.Agents.Defaults.RuntimeControl + if rc.IntentHighConfidence <= 0 || rc.IntentHighConfidence > 1 { + errs = append(errs, fmt.Errorf("agents.defaults.runtime_control.intent_high_confidence must be in (0,1]")) + } + if rc.IntentConfirmMinConfidence < 0 || rc.IntentConfirmMinConfidence >= rc.IntentHighConfidence { + errs = append(errs, fmt.Errorf("agents.defaults.runtime_control.intent_confirm_min_confidence must be >= 0 and < intent_high_confidence")) + } + if rc.IntentMaxInputChars < 200 { + errs = append(errs, fmt.Errorf("agents.defaults.runtime_control.intent_max_input_chars must be >= 200")) + } + if rc.ConfirmTTLSeconds <= 0 { + errs = append(errs, fmt.Errorf("agents.defaults.runtime_control.confirm_ttl_seconds must be > 0")) + } + if rc.ConfirmMaxClarificationTurns < 0 { + errs = append(errs, fmt.Errorf("agents.defaults.runtime_control.confirm_max_clarification_turns must be >= 0")) + } + if rc.AutonomyTickIntervalSec < 5 { + errs = append(errs, fmt.Errorf("agents.defaults.runtime_control.autonomy_tick_interval_sec must be >= 5")) + } + if rc.AutonomyMinRunIntervalSec < 5 { + errs = append(errs, fmt.Errorf("agents.defaults.runtime_control.autonomy_min_run_interval_sec must be >= 5")) + } + if rc.AutonomyIdleThresholdSec < 5 { + errs = append(errs, fmt.Errorf("agents.defaults.runtime_control.autonomy_idle_threshold_sec must be >= 5")) + } + if rc.AutonomyMaxRoundsWithoutUser <= 0 { + errs = append(errs, 
fmt.Errorf("agents.defaults.runtime_control.autonomy_max_rounds_without_user must be > 0")) + } + if rc.AutonomyMaxPendingDurationSec < 10 { + errs = append(errs, fmt.Errorf("agents.defaults.runtime_control.autonomy_max_pending_duration_sec must be >= 10")) + } + if rc.AutonomyMaxConsecutiveStalls <= 0 { + errs = append(errs, fmt.Errorf("agents.defaults.runtime_control.autonomy_max_consecutive_stalls must be > 0")) + } + if rc.AutoLearnMaxRoundsWithoutUser <= 0 { + errs = append(errs, fmt.Errorf("agents.defaults.runtime_control.autolearn_max_rounds_without_user must be > 0")) + } + if rc.RunStateTTLSeconds < 60 { + errs = append(errs, fmt.Errorf("agents.defaults.runtime_control.run_state_ttl_seconds must be >= 60")) + } + if rc.RunStateMax <= 0 { + errs = append(errs, fmt.Errorf("agents.defaults.runtime_control.run_state_max must be > 0")) + } + errs = append(errs, validateNonEmptyStringList("agents.defaults.runtime_control.run_control_latest_keywords", rc.RunControlLatestKeywords)...) + errs = append(errs, validateNonEmptyStringList("agents.defaults.runtime_control.run_control_wait_keywords", rc.RunControlWaitKeywords)...) + errs = append(errs, validateNonEmptyStringList("agents.defaults.runtime_control.run_control_status_keywords", rc.RunControlStatusKeywords)...) + errs = append(errs, validateNonEmptyStringList("agents.defaults.runtime_control.run_control_run_mention_keywords", rc.RunControlRunMentionKeywords)...) + errs = append(errs, validateNonEmptyStringList("agents.defaults.runtime_control.run_control_minute_units", rc.RunControlMinuteUnits)...) + errs = append(errs, validateNonEmptyStringList("agents.defaults.runtime_control.tool_parallel_safe_names", rc.ToolParallelSafeNames)...) 
+ if rc.ToolMaxParallelCalls <= 0 { + errs = append(errs, fmt.Errorf("agents.defaults.runtime_control.tool_max_parallel_calls must be > 0")) + } if cfg.Agents.Defaults.ContextCompaction.Enabled { cc := cfg.Agents.Defaults.ContextCompaction if cc.Mode != "" { @@ -199,3 +251,16 @@ func providerConfigByName(cfg *Config, name string) (ProviderConfig, bool) { pc, ok := cfg.Providers.Proxies[name] return pc, ok } + +func validateNonEmptyStringList(path string, values []string) []error { + if len(values) == 0 { + return nil + } + var errs []error + for i, value := range values { + if strings.TrimSpace(value) == "" { + errs = append(errs, fmt.Errorf("%s[%d] must not be empty", path, i)) + } + } + return errs +} diff --git a/pkg/tools/base.go b/pkg/tools/base.go index 1bf53f7..8fea049 100644 --- a/pkg/tools/base.go +++ b/pkg/tools/base.go @@ -9,6 +9,20 @@ type Tool interface { Execute(ctx context.Context, args map[string]interface{}) (string, error) } +// ParallelSafeTool is an optional capability interface. +// If implemented by a tool, AgentLoop should trust this declaration +// over name-based whitelist when deciding parallel execution safety. +type ParallelSafeTool interface { + ParallelSafe() bool +} + +// ResourceScopedTool is an optional capability interface. +// If implemented by a tool, AgentLoop can avoid running calls that touch +// the same resource keys in parallel. +type ResourceScopedTool interface { + ResourceKeys(args map[string]interface{}) []string +} + func ToolToSchema(tool Tool) map[string]interface{} { return map[string]interface{}{ "type": "function",