diff --git a/cmd/cmd_gateway.go b/cmd/cmd_gateway.go index d2adee4..74c495a 100644 --- a/cmd/cmd_gateway.go +++ b/cmd/cmd_gateway.go @@ -151,8 +151,11 @@ func gatewayCmd() { } return loop.ProcessDirect(cctx, content, sessionKey) }) - registryServer.SetChatHistoryHandler(func(sessionKey string) []map[string]interface{} { - h := loop.GetSessionHistory(sessionKey) + registryServer.SetChatHistoryHandler(func(query api.ChatHistoryQuery) []map[string]interface{} { + h := loop.GetSessionHistory(query.Session) + if query.Around > 0 || query.Before > 0 || query.After > 0 || query.Limit > 0 { + h = loop.GetSessionHistoryWindow(query.Session, query.Around, query.Before, query.After, query.Limit) + } out := make([]map[string]interface{}, 0, len(h)) for _, m := range h { entry := map[string]interface{}{"role": m.Role, "content": m.Content} @@ -166,6 +169,35 @@ func gatewayCmd() { } return out }) + registryServer.SetSessionSearchHandler(func(query api.SessionSearchQuery) []map[string]interface{} { + excludeKey := "" + if query.ExcludeCurrent { + excludeKey = strings.TrimSpace(query.Session) + } + results := loop.SearchSessions(query.Query, query.Kinds, excludeKey, query.Limit) + out := make([]map[string]interface{}, 0, len(results)) + for _, item := range results { + entry := map[string]interface{}{ + "key": item.Key, + "kind": item.Kind, + "updated_at": item.UpdatedAt.UnixMilli(), + "summary": item.Summary, + "score": item.Score, + } + snippets := make([]map[string]interface{}, 0, len(item.Snippets)) + for _, snippet := range item.Snippets { + snippets = append(snippets, map[string]interface{}{ + "seq": snippet.Seq, + "role": snippet.Role, + "segment": snippet.Segment, + "content": snippet.Content, + }) + } + entry["snippets"] = snippets + out = append(out, entry) + } + return out + }) registryServer.SetToolsCatalogHandler(func() interface{} { return loop.GetToolCatalog() }) diff --git a/cmd/cmd_onboard.go b/cmd/cmd_onboard.go index 6e15d05..2050587 100644 --- 
a/cmd/cmd_onboard.go +++ b/cmd/cmd_onboard.go @@ -6,6 +6,7 @@ import ( "os" "path/filepath" + clawgoassets "github.com/YspCoder/clawgo" "github.com/YspCoder/clawgo/pkg/config" ) @@ -42,7 +43,7 @@ func copyEmbeddedToTarget(targetDir string, overwrite func(relPath string) bool) return fmt.Errorf("failed to create target directory: %w", err) } - return fs.WalkDir(embeddedFiles, "workspace", func(path string, d fs.DirEntry, err error) error { + return fs.WalkDir(clawgoassets.WorkspaceTemplates, "workspace", func(path string, d fs.DirEntry, err error) error { if err != nil { return err } @@ -50,7 +51,7 @@ func copyEmbeddedToTarget(targetDir string, overwrite func(relPath string) bool) return nil } - data, err := embeddedFiles.ReadFile(path) + data, err := clawgoassets.WorkspaceTemplates.ReadFile(path) if err != nil { return fmt.Errorf("failed to read embedded file %s: %w", path, err) } diff --git a/cmd/main.go b/cmd/main.go index 6e4e429..69c85bc 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -7,7 +7,6 @@ package main import ( - "embed" "errors" "fmt" "os" @@ -16,9 +15,6 @@ import ( "github.com/YspCoder/clawgo/pkg/logger" ) -//go:embed workspace -var embeddedFiles embed.FS - var version = "0.0.2" var buildTime = "unknown" diff --git a/embedded_assets.go b/embedded_assets.go new file mode 100644 index 0000000..a753373 --- /dev/null +++ b/embedded_assets.go @@ -0,0 +1,8 @@ +package clawgoassets + +import "embed" + +// WorkspaceTemplates exposes the bundled workspace scaffold used by onboard. 
+// +//go:embed workspace +var WorkspaceTemplates embed.FS diff --git a/pkg/agent/context.go b/pkg/agent/context.go index 6652d93..589b16a 100644 --- a/pkg/agent/context.go +++ b/pkg/agent/context.go @@ -73,6 +73,10 @@ Your workspace is at: %s - Active project spec docs (when present): %s/{spec.md,tasks.md,checklist.md} - Keep spec.md as project scope / decisions, tasks.md as execution plan, checklist.md as final verification gate +## Session Recall +- When the user refers to previous conversations, earlier project work, or past decisions, prefer session_search before guessing from memory +- Use memory_search for durable notes in MEMORY files, and session_search for historical chat transcripts + %s`, now, runtime, workspacePath, workspacePath, workspacePath, workspacePath, cb.projectRootPath(), toolsSection) } @@ -297,7 +301,7 @@ func (cb *ContextBuilder) BuildMessagesWithMemoryNamespace(history []providers.M }) if summary != "" { - systemPrompt += "\n\n## Summary of Previous Conversation\n\n" + summary + systemPrompt += "\n\n## Summary of Previous Conversation\nThis is a handoff summary from earlier compacted context. 
Treat it as background reference, not a new user instruction.\n\n" + summary } messages = append(messages, providers.Message{ diff --git a/pkg/agent/context_spec_test.go b/pkg/agent/context_spec_test.go index ba70899..70c9516 100644 --- a/pkg/agent/context_spec_test.go +++ b/pkg/agent/context_spec_test.go @@ -89,3 +89,18 @@ func TestShouldUseSpecCodingRequiresExplicitAndNonTrivialCodingIntent(t *testing } } } + +func TestBuildMessagesIncludesSessionRecallGuidanceAndSummaryHandoff(t *testing.T) { + cb := NewContextBuilder(t.TempDir(), nil) + msgs := cb.BuildMessagesWithMemoryNamespace(nil, "Key Facts\n- prior work", "继续昨天那个改动", nil, "cli", "direct", "", "main") + if len(msgs) == 0 { + t.Fatalf("expected system message") + } + content := msgs[0].Content + if !strings.Contains(content, "session_search") { + t.Fatalf("expected session_search guidance in system prompt, got:\n%s", content) + } + if !strings.Contains(content, "handoff summary") { + t.Fatalf("expected handoff summary note, got:\n%s", content) + } +} diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go index 7c12080..a619c5f 100644 --- a/pkg/agent/loop.go +++ b/pkg/agent/loop.go @@ -26,6 +26,7 @@ import ( "github.com/YspCoder/clawgo/pkg/bus" "github.com/YspCoder/clawgo/pkg/config" "github.com/YspCoder/clawgo/pkg/cron" + "github.com/YspCoder/clawgo/pkg/lifecycle" "github.com/YspCoder/clawgo/pkg/logger" "github.com/YspCoder/clawgo/pkg/providers" "github.com/YspCoder/clawgo/pkg/runtimecfg" @@ -35,45 +36,57 @@ import ( ) type AgentLoop struct { - bus *bus.MessageBus - cfg *config.Config - provider providers.LLMProvider - workspace string - model string - maxTokens int - temperature float64 - maxIterations int - sessions *session.SessionManager - contextBuilder *ContextBuilder - tools *tools.ToolRegistry - compactionEnabled bool - compactionTrigger int - compactionKeepRecent int - heartbeatAckMaxChars int - heartbeatAckToken string - audit *triggerAudit - running bool - sessionScheduler *SessionScheduler - 
providerChain []providerCandidate - providerNames []string - providerPool map[string]providers.LLMProvider - providerResponses map[string]config.ProviderResponsesConfig - providerMaxTokens map[string]int - providerTemperatures map[string]float64 - telegramStreaming bool - providerMu sync.RWMutex - sessionProvider map[string]string - streamMu sync.Mutex - sessionStreamed map[string]bool - subagentManager *tools.SubagentManager - subagentRouter *tools.SubagentRouter - configPath string - subagentDigestMu sync.Mutex - subagentDigestDelay time.Duration - subagentDigests map[string]*subagentDigestState - runMu sync.Mutex - runCancel context.CancelFunc - runWG sync.WaitGroup + bus *bus.MessageBus + cfg *config.Config + provider providers.LLMProvider + workspace string + model string + maxTokens int + temperature float64 + maxIterations int + sessions *session.SessionManager + contextBuilder *ContextBuilder + tools *tools.ToolRegistry + compactionEnabled bool + compactionTrigger int + compactionKeepRecent int + compactionProtectLastN int + compactionTargetRatio float64 + compactionPressureThreshold float64 + compactionMaxSummaryChars int + compactionMaxTranscriptChars int + heartbeatAckMaxChars int + heartbeatAckToken string + audit *triggerAudit + running bool + sessionScheduler *SessionScheduler + providerChain []providerCandidate + providerNames []string + providerPool map[string]providers.LLMProvider + providerResponses map[string]config.ProviderResponsesConfig + providerMaxTokens map[string]int + providerTemperatures map[string]float64 + telegramStreaming bool + providerMu sync.RWMutex + sessionProvider map[string]string + streamMu sync.Mutex + sessionStreamed map[string]bool + subagentManager *tools.SubagentManager + subagentRouter *tools.SubagentRouter + configPath string + subagentDigestMu sync.Mutex + subagentDigestDelay time.Duration + subagentDigests map[string]*subagentDigestState + compactionRunner *lifecycle.LoopRunner + compactionMu sync.Mutex + 
compactionSignal chan struct{} + compactionQueue []string + compactionQueued map[string]struct{} + compactionInflight map[string]struct{} + compactionDirty map[string]struct{} + runMu sync.Mutex + runCancel context.CancelFunc + runWG sync.WaitGroup } type providerCandidate struct { @@ -198,6 +211,7 @@ func NewAgentLoop(cfg *config.Config, msgBus *bus.MessageBus, provider providers return h }, )) + toolsRegistry.Register(tools.NewSessionSearchTool(sessionsManager)) // Register edit file tool editFileTool := tools.NewEditFileTool(workspace) @@ -221,35 +235,45 @@ func NewAgentLoop(cfg *config.Config, msgBus *bus.MessageBus, provider providers toolsRegistry.Register(tools.NewSystemInfoTool()) loop := &AgentLoop{ - bus: msgBus, - cfg: cfg, - provider: provider, - workspace: workspace, - model: provider.GetDefaultModel(), - maxTokens: cfg.Agents.Defaults.MaxTokens, - temperature: cfg.Agents.Defaults.Temperature, - maxIterations: cfg.Agents.Defaults.MaxToolIterations, - sessions: sessionsManager, - contextBuilder: NewContextBuilder(workspace, func() []string { return toolsRegistry.GetSummaries() }), - tools: toolsRegistry, - compactionEnabled: cfg.Agents.Defaults.ContextCompaction.Enabled, - compactionTrigger: cfg.Agents.Defaults.ContextCompaction.TriggerMessages, - compactionKeepRecent: cfg.Agents.Defaults.ContextCompaction.KeepRecentMessages, - heartbeatAckMaxChars: cfg.Agents.Defaults.Heartbeat.AckMaxChars, - heartbeatAckToken: loadHeartbeatAckToken(workspace), - audit: newTriggerAudit(workspace), - running: false, - sessionScheduler: NewSessionScheduler(0), - sessionProvider: map[string]string{}, - sessionStreamed: map[string]bool{}, - providerResponses: map[string]config.ProviderResponsesConfig{}, - providerMaxTokens: map[string]int{}, - providerTemperatures: map[string]float64{}, - telegramStreaming: cfg.Channels.Telegram.Streaming, - subagentManager: subagentManager, - subagentRouter: subagentRouter, - subagentDigestDelay: 5 * time.Second, - subagentDigests: 
map[string]*subagentDigestState{}, + bus: msgBus, + cfg: cfg, + provider: provider, + workspace: workspace, + model: provider.GetDefaultModel(), + maxTokens: cfg.Agents.Defaults.MaxTokens, + temperature: cfg.Agents.Defaults.Temperature, + maxIterations: cfg.Agents.Defaults.MaxToolIterations, + sessions: sessionsManager, + contextBuilder: NewContextBuilder(workspace, func() []string { return toolsRegistry.GetSummaries() }), + tools: toolsRegistry, + compactionEnabled: cfg.Agents.Defaults.ContextCompaction.Enabled, + compactionTrigger: cfg.Agents.Defaults.ContextCompaction.TriggerMessages, + compactionKeepRecent: cfg.Agents.Defaults.ContextCompaction.KeepRecentMessages, + compactionProtectLastN: normalizePositiveInt(cfg.Agents.Defaults.ContextCompaction.ProtectLastN, 12), + compactionTargetRatio: normalizeCompactionRatio(cfg.Agents.Defaults.ContextCompaction.TargetRatio, 0.35), + compactionPressureThreshold: normalizeCompactionRatio(cfg.Agents.Defaults.ContextCompaction.PressureThreshold, 0.8), + compactionMaxSummaryChars: normalizePositiveInt(cfg.Agents.Defaults.ContextCompaction.MaxSummaryChars, 6000), + compactionMaxTranscriptChars: normalizePositiveInt(cfg.Agents.Defaults.ContextCompaction.MaxTranscriptChars, 20000), + heartbeatAckMaxChars: cfg.Agents.Defaults.Heartbeat.AckMaxChars, + heartbeatAckToken: loadHeartbeatAckToken(workspace), + audit: newTriggerAudit(workspace), + running: false, + sessionScheduler: NewSessionScheduler(0), + sessionProvider: map[string]string{}, + sessionStreamed: map[string]bool{}, + providerResponses: map[string]config.ProviderResponsesConfig{}, + providerMaxTokens: map[string]int{}, + providerTemperatures: map[string]float64{}, + telegramStreaming: cfg.Channels.Telegram.Streaming, + subagentManager: subagentManager, + subagentRouter: subagentRouter, + subagentDigestDelay: 5 * time.Second, + subagentDigests: map[string]*subagentDigestState{}, + compactionRunner: lifecycle.NewLoopRunner(), + compactionSignal: make(chan struct{}, 1), + 
compactionQueued: map[string]struct{}{}, + compactionInflight: map[string]struct{}{}, + compactionDirty: map[string]struct{}{}, } if _, primaryModel := config.ParseProviderModelRef(cfg.Agents.Defaults.Model.Primary); strings.TrimSpace(primaryModel) != "" { loop.model = strings.TrimSpace(primaryModel) @@ -335,6 +359,9 @@ func NewAgentLoop(cfg *config.Config, msgBus *bus.MessageBus, provider providers if run == nil { return "", fmt.Errorf("subagent run is nil") } + if run.MaxToolIterations <= 0 { + run.MaxToolIterations = loop.maxIterations + } sessionKey := strings.TrimSpace(run.SessionKey) if sessionKey == "" { sessionKey = fmt.Sprintf("subagent:%s", strings.TrimSpace(run.ID)) @@ -457,6 +484,9 @@ func (al *AgentLoop) Stop() { } al.running = false al.runWG.Wait() + if al.compactionRunner != nil { + al.compactionRunner.Stop() + } } func (al *AgentLoop) buildSessionShards(ctx context.Context) []chan bus.InboundMessage { @@ -476,11 +506,12 @@ func (al *AgentLoop) buildSessionShards(ctx context.Context) []chan bus.InboundM return shards } -func (al *AgentLoop) tryFallbackProviders(ctx context.Context, msg bus.InboundMessage, messages []providers.Message, toolDefs []providers.ToolDefinition, primaryErr error) (*providers.LLMResponse, string, error) { +func (al *AgentLoop) tryFallbackProviders(ctx context.Context, msg bus.InboundMessage, messages []providers.Message, toolDefs []providers.ToolDefinition, primaryErr error) (*providers.LLMResponse, string, int, error) { if len(al.providerChain) <= 1 { - return nil, "", primaryErr + return nil, "", 0, primaryErr } lastErr := primaryErr + attempts := 0 candidateNames := make([]string, 0, len(al.providerChain)-1) for _, candidate := range al.providerChain[1:] { candidateNames = append(candidateNames, candidate.name) @@ -507,16 +538,17 @@ func (al *AgentLoop) tryFallbackProviders(ctx context.Context, msg bus.InboundMe lastErr = err continue } + attempts++ fallbackOptions := al.buildResponsesOptionsForProvider(msg.SessionKey, 
candidate.name, int64(al.maxTokensForProvider(candidate.name)), al.temperatureForProvider(candidate.name)) resp, err := p.Chat(ctx, messages, toolDefs, candidateModel, fallbackOptions) if err == nil { al.setSessionProvider(msg.SessionKey, candidate.name) logger.WarnCF("agent", logger.C0150, map[string]interface{}{"provider": candidate.name, "model": candidateModel, "ref": candidate.ref}) - return resp, candidate.name, nil + return resp, candidate.name, attempts, nil } lastErr = err } - return nil, "", lastErr + return nil, "", attempts, lastErr } func (al *AgentLoop) ensureProviderCandidate(candidate providerCandidate) (providers.LLMProvider, string, error) { @@ -865,6 +897,9 @@ type llmTurnLoopResult struct { pendingPersist []providers.Message finalContent string iteration int + attemptCount int + restartCount int + failureCode string hasToolActivity bool } @@ -960,39 +995,150 @@ func (al *AgentLoop) executeResponseToolCalls(cfg llmTurnLoopConfig, iteration i return results } -func (al *AgentLoop) requestLLMResponse(cfg llmTurnLoopConfig, activeProvider providers.LLMProvider, activeModel string, messages []providers.Message, providerToolDefs []providers.ToolDefinition, options map[string]interface{}) (*providers.LLMResponse, error) { +func (al *AgentLoop) requestLLMResponse(cfg llmTurnLoopConfig, activeProvider providers.LLMProvider, activeModel string, messages []providers.Message, providerToolDefs []providers.ToolDefinition, options map[string]interface{}) (*providers.LLMResponse, int, error) { if cfg.enableStreaming { if sp, ok := activeProvider.(providers.StreamingLLMProvider); ok { - streamText := "" - lastPush := time.Now().Add(-time.Second) - return sp.ChatStream(cfg.ctx, messages, providerToolDefs, activeModel, options, func(delta string) { - if strings.TrimSpace(delta) == "" { - return - } - streamText += delta - if time.Since(lastPush) < 450*time.Millisecond { - return - } - if !shouldFlushTelegramStreamSnapshot(streamText) { - return - } - lastPush = 
time.Now() - replyID := "" - if cfg.triggerMsg.Metadata != nil { - replyID = cfg.triggerMsg.Metadata["message_id"] - } - al.bus.PublishOutbound(bus.OutboundMessage{ - Channel: cfg.toolChannel, - ChatID: cfg.toolChatID, - Content: streamText, - Action: "stream", - ReplyToID: replyID, - }) - al.markSessionStreamed(cfg.sessionKey) - }) + return al.requestStreamingLLMResponse(cfg, sp, activeProvider, activeModel, messages, providerToolDefs, options) } } - return activeProvider.Chat(cfg.ctx, messages, providerToolDefs, activeModel, options) + resp, err := activeProvider.Chat(cfg.ctx, messages, providerToolDefs, activeModel, options) + return resp, 1, err +} + +func (al *AgentLoop) requestStreamingLLMResponse(cfg llmTurnLoopConfig, streamer providers.StreamingLLMProvider, fallbackProvider providers.LLMProvider, activeModel string, messages []providers.Message, providerToolDefs []providers.ToolDefinition, options map[string]interface{}) (*providers.LLMResponse, int, error) { + streamText := "" + lastPush := time.Now().Add(-time.Second) + type streamResult struct { + resp *providers.LLMResponse + err error + } + streamCtx, cancel := context.WithCancel(cfg.ctx) + defer cancel() + resultCh := make(chan streamResult, 1) + var deltaMu sync.Mutex + firstDeltaSeen := false + lastDeltaAt := time.Now() + go func() { + resp, err := streamer.ChatStream(streamCtx, messages, providerToolDefs, activeModel, options, func(delta string) { + if strings.TrimSpace(delta) == "" { + return + } + deltaMu.Lock() + firstDeltaSeen = true + lastDeltaAt = time.Now() + deltaMu.Unlock() + streamText += delta + if time.Since(lastPush) < 450*time.Millisecond { + return + } + if !shouldFlushTelegramStreamSnapshot(streamText) { + return + } + lastPush = time.Now() + replyID := "" + if cfg.triggerMsg.Metadata != nil { + replyID = cfg.triggerMsg.Metadata["message_id"] + } + al.bus.PublishOutbound(bus.OutboundMessage{ + Channel: cfg.toolChannel, + ChatID: cfg.toolChatID, + Content: streamText, + Action: 
"stream", + ReplyToID: replyID, + }) + al.markSessionStreamed(cfg.sessionKey) + }) + resultCh <- streamResult{resp: resp, err: err} + }() + + firstDeltaTimeout, idleDeltaTimeout := streamMonitorTimeouts(cfg.ctx) + staleTriggered := false + ticker := time.NewTicker(250 * time.Millisecond) + defer ticker.Stop() + for { + select { + case result := <-resultCh: + if result.err == nil { + return result.resp, 1, nil + } + deltaMu.Lock() + streamStarted := firstDeltaSeen + deltaMu.Unlock() + if staleTriggered && streamStarted { + return nil, 1, providers.NewProviderExecutionError("stream_stale", result.err.Error(), "stream", true, "") + } + if !streamStarted { + resp, err := fallbackProvider.Chat(cfg.ctx, messages, providerToolDefs, activeModel, options) + return resp, 2, err + } + return nil, 1, result.err + case <-ticker.C: + deltaMu.Lock() + streamStarted := firstDeltaSeen + idleFor := time.Since(lastDeltaAt) + deltaMu.Unlock() + timeout := firstDeltaTimeout + if streamStarted { + timeout = idleDeltaTimeout + } + if timeout > 0 && idleFor > timeout { + staleTriggered = true + cancel() + } + case <-cfg.ctx.Done(): + cancel() + return nil, 1, cfg.ctx.Err() + } + } +} + +func streamMonitorTimeouts(ctx context.Context) (time.Duration, time.Duration) { + first := 20 * time.Second + idle := 45 * time.Second + if deadline, ok := ctx.Deadline(); ok { + remaining := time.Until(deadline) + if remaining > 0 { + firstMin := 5 * time.Second + idleMin := 15 * time.Second + if remaining < firstMin { + firstMin = maxDuration(100*time.Millisecond, remaining/4) + } + if remaining < idleMin { + idleMin = maxDuration(250*time.Millisecond, remaining/3) + } + first = clampDuration(remaining/4, firstMin, first) + idle = clampDuration(remaining/3, idleMin, idle) + } + } + return first, idle +} + +func clampDuration(value, min, max time.Duration) time.Duration { + if value < min { + return min + } + if max > 0 && value > max { + return max + } + return value +} + +func maxDuration(a, b 
time.Duration) time.Duration { + if a > b { + return a + } + return b +} + +func (al *AgentLoop) maxIterationsForContext(ctx context.Context) int { + limit := al.maxIterations + if limit < 1 { + limit = 1 + } + if budget, ok := tools.SubagentIterationBudget(ctx); ok && budget > 0 { + return budget + } + return limit } func (al *AgentLoop) runLLMTurnLoop(cfg llmTurnLoopConfig) (llmTurnLoopResult, error) { @@ -1000,7 +1146,7 @@ func (al *AgentLoop) runLLMTurnLoop(cfg llmTurnLoopConfig) (llmTurnLoopResult, e messages: append([]providers.Message(nil), cfg.messages...), pendingPersist: make([]providers.Message, 0, 16), } - maxAllowed := al.maxIterations + maxAllowed := al.maxIterationsForContext(cfg.ctx) if maxAllowed < 1 { maxAllowed = 1 } @@ -1027,14 +1173,19 @@ func (al *AgentLoop) runLLMTurnLoop(cfg llmTurnLoopConfig) (llmTurnLoopResult, e logLLMTurnRequest(result.iteration, al.maxIterations, providerName, activeModel, result.messages, providerToolDefs, maxTokens, temperature) options := al.buildResponsesOptions(cfg.sessionKey, int64(maxTokens), temperature) - response, err := al.requestLLMResponse(cfg, activeProvider, activeModel, result.messages, providerToolDefs, options) + response, attempts, err := al.requestLLMResponse(cfg, activeProvider, activeModel, result.messages, providerToolDefs, options) + result.attemptCount += attempts if err != nil { - if fb, _, ferr := al.tryFallbackProviders(cfg.ctx, cfg.triggerMsg, result.messages, providerToolDefs, err); ferr == nil && fb != nil { + result.failureCode = classifyLLMFailureCode(err) + if fb, _, fallbackAttempts, ferr := al.tryFallbackProviders(cfg.ctx, cfg.triggerMsg, result.messages, providerToolDefs, err); ferr == nil && fb != nil { response = fb + result.attemptCount += fallbackAttempts err = nil } else { + result.attemptCount += fallbackAttempts err = ferr + result.failureCode = classifyLLMFailureCode(err) } } if err != nil { @@ -1059,9 +1210,6 @@ func (al *AgentLoop) runLLMTurnLoop(cfg llmTurnLoopConfig) 
(llmTurnLoopResult, e result.messages = append(result.messages, assistantMsg) result.pendingPersist = append(result.pendingPersist, assistantMsg) result.hasToolActivity = true - if maxAllowed < result.iteration+al.maxIterations { - maxAllowed = result.iteration + al.maxIterations - } for _, toolResultMsg := range al.executeResponseToolCalls(cfg, result.iteration, response) { result.messages = append(result.messages, toolResultMsg) @@ -1069,7 +1217,8 @@ func (al *AgentLoop) runLLMTurnLoop(cfg llmTurnLoopConfig) (llmTurnLoopResult, e } } - return result, nil + result.failureCode = "retry_limit" + return result, fmt.Errorf("max tool iterations exceeded (%d)", maxAllowed) } func (al *AgentLoop) logInboundMessageStart(msg bus.InboundMessage) { @@ -1083,7 +1232,7 @@ func (al *AgentLoop) logInboundMessageStart(msg bus.InboundMessage) { } func (al *AgentLoop) prepareUserMessageContext(msg bus.InboundMessage, memoryNamespace string) ([]providers.Message, string) { - history := al.sessions.GetHistory(msg.SessionKey) + history := al.sessions.GetPromptHistory(msg.SessionKey) summary := al.sessions.GetSummary(msg.SessionKey) al.sessions.AddMessage(msg.SessionKey, "user", msg.Content) if explicitPref := ExtractLanguagePreference(msg.Content); explicitPref != "" { @@ -1113,12 +1262,11 @@ func (al *AgentLoop) finalizeUserMessage(sessionKey, responseLang string, pendin Content: finalContent, }) al.sessions.SetLastLanguage(sessionKey, responseLang) - al.compactSessionIfNeeded(sessionKey) - _ = al.sessions.Save(al.sessions.GetOrCreate(sessionKey)) + al.enqueueSessionCompaction(sessionKey) } func (al *AgentLoop) prepareSystemMessageContext(sessionKey string, msg bus.InboundMessage, originChannel, originChatID string) ([]providers.Message, string) { - history := al.sessions.GetHistory(sessionKey) + history := al.sessions.GetPromptHistory(sessionKey) summary := al.sessions.GetSummary(sessionKey) preferredLang, lastLang := al.sessions.GetLanguagePreferences(sessionKey) responseLang := 
DetectResponseLanguage(msg.Content, preferredLang, lastLang) @@ -1144,8 +1292,7 @@ func (al *AgentLoop) finalizeSystemMessage(sessionKey, responseLang string, msg Content: finalContent, }) al.sessions.SetLastLanguage(sessionKey, responseLang) - al.compactSessionIfNeeded(sessionKey) - _ = al.sessions.Save(al.sessions.GetOrCreate(sessionKey)) + al.enqueueSessionCompaction(sessionKey) } func (al *AgentLoop) startSpecTaskForMessage(msg bus.InboundMessage) specCodingTaskRef { @@ -1372,6 +1519,14 @@ func (al *AgentLoop) GetSessionHistory(sessionKey string) []providers.Message { return al.sessions.GetHistory(sessionKey) } +func (al *AgentLoop) GetSessionHistoryWindow(sessionKey string, around, before, after, limit int) []providers.Message { + return al.sessions.GetHistoryWindow(sessionKey, around, before, after, limit) +} + +func (al *AgentLoop) SearchSessions(query string, kinds []string, excludeKey string, limit int) []session.SessionSearchResult { + return al.sessions.Search(query, kinds, excludeKey, limit) +} + func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage) (string, error) { if msg.SessionKey == "" { msg.SessionKey = "main" @@ -1411,9 +1566,20 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage) logDirectResponse: true, }) if err != nil { + tools.RecordSubagentExecutionStats(ctx, tools.SubagentExecutionStats{ + Iterations: loopResult.iteration, + Attempts: loopResult.attemptCount, + Restarts: loopResult.restartCount, + FailureCode: classifyLLMFailureCode(err), + }) al.reopenSpecTaskOnError(specTaskRef, msg, err) return "", err } + tools.RecordSubagentExecutionStats(ctx, tools.SubagentExecutionStats{ + Iterations: loopResult.iteration, + Attempts: loopResult.attemptCount, + Restarts: loopResult.restartCount, + }) finalContent, userContent := al.finalizeUserTurnResponse(ctx, msg, responseLang, loopResult) // Log response preview (original content) @@ -1940,32 +2106,215 @@ func isRemoteReference(ref string) 
bool { strings.HasPrefix(trimmed, "data:") } -// GetStartupInfo returns information about loaded tools and skills for logging. -func (al *AgentLoop) compactSessionIfNeeded(sessionKey string) { - if !al.compactionEnabled { +func (al *AgentLoop) ensureCompactionRunnerStarted() { + if al == nil || !al.compactionEnabled || al.compactionRunner == nil { return } + al.compactionRunner.Start(func(stop <-chan struct{}) { + al.runCompactionLoop(stop) + }) +} + +func (al *AgentLoop) signalCompactionWorker() { + if al == nil || al.compactionSignal == nil { + return + } + select { + case al.compactionSignal <- struct{}{}: + default: + } +} + +func (al *AgentLoop) enqueueSessionCompaction(sessionKey string) { + if al == nil || !al.compactionEnabled { + return + } + key := strings.TrimSpace(sessionKey) + if key == "" { + return + } + al.compactionMu.Lock() + if _, ok := al.compactionInflight[key]; ok { + al.compactionDirty[key] = struct{}{} + al.compactionMu.Unlock() + al.signalCompactionWorker() + return + } + if _, ok := al.compactionQueued[key]; ok { + al.compactionMu.Unlock() + al.signalCompactionWorker() + return + } + al.compactionQueue = append(al.compactionQueue, key) + al.compactionQueued[key] = struct{}{} + al.compactionMu.Unlock() + al.ensureCompactionRunnerStarted() + al.signalCompactionWorker() +} + +func (al *AgentLoop) dequeueSessionCompaction() (string, bool) { + if al == nil { + return "", false + } + al.compactionMu.Lock() + defer al.compactionMu.Unlock() + if len(al.compactionQueue) == 0 { + return "", false + } + key := al.compactionQueue[0] + al.compactionQueue = al.compactionQueue[1:] + delete(al.compactionQueued, key) + al.compactionInflight[key] = struct{}{} + return key, true +} + +func (al *AgentLoop) finishSessionCompaction(sessionKey string, retry bool) { + if al == nil { + return + } + key := strings.TrimSpace(sessionKey) + if key == "" { + return + } + shouldWake := false + al.compactionMu.Lock() + delete(al.compactionInflight, key) + if _, dirty := 
al.compactionDirty[key]; dirty { + delete(al.compactionDirty, key) + retry = true + } + if retry { + if _, queued := al.compactionQueued[key]; !queued { + al.compactionQueue = append(al.compactionQueue, key) + al.compactionQueued[key] = struct{}{} + shouldWake = true + } + } + al.compactionMu.Unlock() + if shouldWake { + al.signalCompactionWorker() + } +} + +func (al *AgentLoop) runCompactionLoop(stop <-chan struct{}) { + if al == nil { + return + } + for { + key, ok := al.dequeueSessionCompaction() + if !ok { + select { + case <-stop: + return + case <-al.compactionSignal: + continue + } + } + retry := al.runSessionCompactionJob(stop, key) + al.finishSessionCompaction(key, retry) + } +} + +func (al *AgentLoop) runSessionCompactionJob(stop <-chan struct{}, sessionKey string) bool { + if al == nil || strings.TrimSpace(sessionKey) == "" { + return false + } + jobCtx, cancel := al.newBackgroundCompactionContext(stop, sessionKey) + defer cancel() + applied, changed, needsRetry := al.compactSessionIfNeeded(jobCtx, sessionKey) + if changed && !applied { + return true + } + return needsRetry +} + +func (al *AgentLoop) newBackgroundCompactionContext(stop <-chan struct{}, sessionKey string) (context.Context, context.CancelFunc) { + base, baseCancel := context.WithCancel(context.Background()) + done := make(chan struct{}) + go func() { + select { + case <-stop: + baseCancel() + case <-done: + } + }() + timeout := al.compactionTimeoutForSession(sessionKey) + ctx, timeoutCancel := context.WithTimeout(base, timeout) + return ctx, func() { + close(done) + timeoutCancel() + baseCancel() + } +} + +func (al *AgentLoop) compactionTimeoutForSession(sessionKey string) time.Duration { + timeout := 30 * time.Second + if al != nil && al.cfg != nil { + name := strings.TrimSpace(al.getSessionProvider(sessionKey)) + if name == "" && len(al.providerNames) > 0 { + name = al.providerNames[0] + } + if pc, ok := config.ProviderConfigByName(al.cfg, name); ok && pc.TimeoutSec > 0 { + 
providerTimeout := time.Duration(pc.TimeoutSec) * time.Second + if providerTimeout < timeout { + timeout = providerTimeout + } + } + } + if timeout < 5*time.Second { + timeout = 5 * time.Second + } + return timeout +} + +func (al *AgentLoop) shouldCompactSnapshot(ctx context.Context, sessionKey string, snapshot session.SessionCompactionSnapshot) (bool, int, float64) { trigger := al.compactionTrigger if trigger <= 0 { trigger = 60 } - keepRecent := al.compactionKeepRecent - if keepRecent <= 0 || keepRecent >= trigger { - keepRecent = trigger / 2 - if keepRecent < 10 { - keepRecent = 10 - } + pressure := al.estimateCompactionPressure(ctx, sessionKey, snapshot.History, snapshot.Summary) + if len(snapshot.History) <= trigger && pressure < al.compactionPressureThreshold { + return false, trigger, pressure } - h := al.sessions.GetHistory(sessionKey) - if len(h) <= trigger { - return + return true, trigger, pressure +} + +// compactSessionIfNeeded compacts the session's history when the trigger/pressure thresholds are exceeded, reporting whether compaction was applied, whether the snapshot changed concurrently, and whether a retry is needed. 
+func (al *AgentLoop) compactSessionIfNeeded(ctx context.Context, sessionKey string) (applied bool, changed bool, needsRetry bool) { + if !al.compactionEnabled { + return false, false, false } - removed := len(h) - keepRecent - tpl := "[runtime-compaction] removed %d old messages, kept %d recent messages" - note := fmt.Sprintf(tpl, removed, keepRecent) - if al.sessions.CompactSession(sessionKey, keepRecent, note) { - _ = al.sessions.Save(al.sessions.GetOrCreate(sessionKey)) + if ctx == nil { + ctx = context.Background() } + snapshot := al.sessions.CompactionSnapshot(sessionKey) + if len(snapshot.History) == 0 { + return false, false, false + } + shouldCompact, trigger, _ := al.shouldCompactSnapshot(ctx, sessionKey, snapshot) + if !shouldCompact { + return false, false, false + } + keepRecent := al.compactionKeepCount(len(snapshot.History), trigger) + if keepRecent >= len(snapshot.History) { + return false, false, false + } + removed := len(snapshot.History) - keepRecent + older := append([]providers.Message(nil), snapshot.History[:removed]...) + recent := append([]providers.Message(nil), snapshot.History[removed:]...) + summary := al.buildCompactionSummary(ctx, sessionKey, older, snapshot.Summary) + if ctx.Err() != nil { + return false, false, true + } + if summary == "" { + summary = fmt.Sprintf("Key Facts\n- Runtime compaction removed %d older messages.\n\nDecisions\n\nOpen Items\n\nNext Steps", removed) + } + if al.sessions.ApplyCompactionIfUnchanged(sessionKey, snapshot.NextSeq, snapshot.Summary, recent, summary) { + return true, false, false + } + next := al.sessions.CompactionSnapshot(sessionKey) + shouldRetry, _, _ := al.shouldCompactSnapshot(context.Background(), sessionKey, next) + return false, true, shouldRetry } // RunStartupSelfCheckAllSessions runs startup compaction checks across loaded sessions. 
@@ -1975,35 +2324,16 @@ func (al *AgentLoop) RunStartupSelfCheckAllSessions(ctx context.Context) Startup return report } - trigger := al.compactionTrigger - if trigger <= 0 { - trigger = 60 - } - keepRecent := al.compactionKeepRecent - if keepRecent <= 0 || keepRecent >= trigger { - keepRecent = trigger / 2 - if keepRecent < 10 { - keepRecent = 10 - } - } - for _, key := range al.sessions.Keys() { select { case <-ctx.Done(): return report default: } - - history := al.sessions.GetHistory(key) - if len(history) <= trigger { - continue - } - - removed := len(history) - keepRecent - tpl := "[startup-compaction] removed %d old messages, kept %d recent messages" - note := fmt.Sprintf(tpl, removed, keepRecent) - if al.sessions.CompactSession(key, keepRecent, note) { - al.sessions.Save(al.sessions.GetOrCreate(key)) + jobCtx, cancel := context.WithTimeout(ctx, al.compactionTimeoutForSession(key)) + applied, _, _ := al.compactSessionIfNeeded(jobCtx, key) + cancel() + if applied { report.CompactedSessions++ } } @@ -2011,6 +2341,201 @@ func (al *AgentLoop) RunStartupSelfCheckAllSessions(ctx context.Context) Startup return report } +func (al *AgentLoop) buildCompactionSummary(ctx context.Context, sessionKey string, older []providers.Message, existingSummary string) string { + if al == nil || al.provider == nil || len(older) == 0 { + return strings.TrimSpace(existingSummary) + } + pruned := pruneCompactionMessages(older, al.compactionMaxTranscriptChars) + if compactor, ok := al.provider.(providers.ResponsesCompactor); ok && compactor.SupportsResponsesCompact() { + if summary, err := compactor.BuildSummaryViaResponsesCompact(ctx, al.model, existingSummary, pruned, al.compactionMaxSummaryChars); err == nil { + return strings.TrimSpace(summary) + } + } + + payload, err := json.Marshal(compactionTranscript(pruned)) + if err != nil { + return strings.TrimSpace(existingSummary) + } + resp, err := al.provider.Chat(ctx, []providers.Message{ + { + Role: "system", + Content: "You are 
compacting a previous conversation window for a future handoff. " + + "Return concise markdown with exactly these sections: Key Facts, Decisions, Open Items, Next Steps. " + + "Preserve concrete decisions, file paths, errors, and pending work. Do not address the user.", + }, + { + Role: "user", + Content: "Existing summary:\n" + strings.TrimSpace(existingSummary) + "\n\nEarlier conversation JSON:\n" + string(payload), + }, + }, nil, al.model, map[string]interface{}{"max_tokens": 900}) + if err != nil || resp == nil { + return strings.TrimSpace(existingSummary) + } + out := strings.TrimSpace(resp.Content) + if al.compactionMaxSummaryChars > 0 && len(out) > al.compactionMaxSummaryChars { + out = out[:al.compactionMaxSummaryChars] + } + return out +} + +func pruneCompactionMessages(messages []providers.Message, maxTranscriptChars int) []providers.Message { + out := make([]providers.Message, 0, len(messages)) + remaining := maxTranscriptChars + for _, msg := range messages { + cp := msg + if strings.EqualFold(strings.TrimSpace(cp.Role), "tool") && len(cp.Content) > 400 { + cp.Content = "[older tool output trimmed during context compaction]" + } else if len(cp.Content) > 2000 { + cp.Content = cp.Content[:2000] + "\n...[truncated for compaction]..." + } + if remaining > 0 && len(cp.Content) > remaining { + cp.Content = strings.TrimSpace(cp.Content[:remaining]) + "\n...[truncated for compaction budget]..." 
+ } + if remaining > 0 { + remaining -= len(cp.Content) + if remaining <= 0 { + out = append(out, cp) + break + } + } + out = append(out, cp) + } + return out +} + +func compactionTranscript(messages []providers.Message) []map[string]interface{} { + out := make([]map[string]interface{}, 0, len(messages)) + for _, msg := range messages { + entry := map[string]interface{}{ + "role": msg.Role, + "content": strings.TrimSpace(msg.Content), + } + if strings.TrimSpace(msg.ToolCallID) != "" { + entry["tool_call_id"] = msg.ToolCallID + } + if len(msg.ToolCalls) > 0 { + entry["tool_calls"] = msg.ToolCalls + } + out = append(out, entry) + } + return out +} + +func (al *AgentLoop) compactionKeepCount(historyLen, trigger int) int { + if historyLen <= 0 { + return 0 + } + keep := al.compactionProtectLastN + if keep <= 0 { + keep = al.compactionKeepRecent + } + if keep <= 0 { + keep = 12 + } + if ratio := al.compactionTargetRatio; ratio > 0 { + ratioKeep := int(math.Ceil(float64(historyLen) * ratio)) + if ratioKeep > keep { + keep = ratioKeep + } + } + if trigger > 0 && keep >= trigger { + keep = maxLoopInt(trigger-1, al.compactionProtectLastN, 1) + } + if keep >= historyLen { + keep = historyLen - 1 + } + if keep < 1 { + keep = 1 + } + return keep +} + +func (al *AgentLoop) estimateCompactionPressure(ctx context.Context, sessionKey string, history []providers.Message, summary string) float64 { + maxTokens := al.maxTokensForSession(sessionKey) + if maxTokens <= 0 { + return 0 + } + estimateMessages := make([]providers.Message, 0, len(history)+1) + estimateMessages = append(estimateMessages, history...) + if strings.TrimSpace(summary) != "" { + estimateMessages = append([]providers.Message{{ + Role: "system", + Content: "Handoff summary:\n" + strings.TrimSpace(summary), + }}, estimateMessages...) 
+ } + if activeProvider, activeModel, _, err := al.activeProviderForSession(sessionKey); err == nil { + if counter, ok := activeProvider.(providers.TokenCounter); ok { + if usage, err := counter.CountTokens(ctx, estimateMessages, nil, activeModel, nil); err == nil && usage != nil { + total := usage.TotalTokens + if total <= 0 { + total = usage.PromptTokens + } + if total > 0 { + return float64(total) / float64(maxTokens) + } + } + } + } + charCount := 0 + for _, msg := range estimateMessages { + charCount += len(msg.Content) + } + approxTokens := charCount / 4 + if approxTokens <= 0 && charCount > 0 { + approxTokens = 1 + } + return float64(approxTokens) / float64(maxTokens) +} + +func normalizePositiveInt(value, fallback int) int { + if value > 0 { + return value + } + return fallback +} + +func normalizeCompactionRatio(value, fallback float64) float64 { + if value > 0 && value <= 1 { + return value + } + return fallback +} + +func maxLoopInt(values ...int) int { + best := 0 + for _, value := range values { + if value > best { + best = value + } + } + return best +} + +func classifyLLMFailureCode(err error) string { + if err == nil { + return "" + } + if errors.Is(err, context.DeadlineExceeded) { + return "timeout" + } + if code := providers.ExecutionErrorCode(err); strings.TrimSpace(code) != "" { + return strings.TrimSpace(code) + } + lower := strings.ToLower(strings.TrimSpace(err.Error())) + switch { + case strings.Contains(lower, "max tool iterations"): + return "retry_limit" + case strings.Contains(lower, "stream stale"): + return "stream_stale" + case strings.Contains(lower, "stream failed"): + return "stream_failed" + case strings.Contains(lower, "thinking budget exhausted"), strings.Contains(lower, "continuation exhausted"): + return "continuation_exhausted" + default: + return "" + } +} + func (al *AgentLoop) GetStartupInfo() map[string]interface{} { info := make(map[string]interface{}) diff --git a/pkg/agent/loop_codex_options_test.go 
b/pkg/agent/loop_codex_options_test.go index 517f302..348daa7 100644 --- a/pkg/agent/loop_codex_options_test.go +++ b/pkg/agent/loop_codex_options_test.go @@ -193,10 +193,13 @@ func TestTryFallbackProvidersUsesFallbackProviderOptionsAndPersistsSelection(t * }, } - resp, providerName, err := loop.tryFallbackProviders(context.Background(), bus.InboundMessage{SessionKey: "chat-1"}, nil, nil, errors.New("primary failed")) + resp, providerName, attempts, err := loop.tryFallbackProviders(context.Background(), bus.InboundMessage{SessionKey: "chat-1"}, nil, nil, errors.New("primary failed")) if err != nil { t.Fatalf("expected fallback success, got %v", err) } + if attempts != 1 { + t.Fatalf("expected one fallback attempt, got %d", attempts) + } if resp == nil || resp.Content != "fallback" { t.Fatalf("unexpected fallback response: %#v", resp) } diff --git a/pkg/agent/reliability_test.go b/pkg/agent/reliability_test.go new file mode 100644 index 0000000..615ffdc --- /dev/null +++ b/pkg/agent/reliability_test.go @@ -0,0 +1,348 @@ +package agent + +import ( + "context" + "strings" + "sync" + "testing" + "time" + + "github.com/YspCoder/clawgo/pkg/bus" + "github.com/YspCoder/clawgo/pkg/lifecycle" + "github.com/YspCoder/clawgo/pkg/providers" + "github.com/YspCoder/clawgo/pkg/session" + toolspkg "github.com/YspCoder/clawgo/pkg/tools" +) + +type pressureProvider struct { + tokens int +} + +func (p *pressureProvider) Chat(ctx context.Context, messages []providers.Message, tools []providers.ToolDefinition, model string, options map[string]interface{}) (*providers.LLMResponse, error) { + return &providers.LLMResponse{Content: "Key Facts\n- compacted", FinishReason: "stop"}, nil +} + +func (p *pressureProvider) GetDefaultModel() string { return "pressure-model" } + +func (p *pressureProvider) CountTokens(ctx context.Context, messages []providers.Message, tools []providers.ToolDefinition, model string, options map[string]interface{}) (*providers.UsageInfo, error) { + return 
&providers.UsageInfo{PromptTokens: p.tokens, TotalTokens: p.tokens}, nil +} + +type fallbackStreamingProvider struct { + stream func(ctx context.Context, onDelta func(string)) (*providers.LLMResponse, error) + chat func(ctx context.Context) (*providers.LLMResponse, error) +} + +func (p *fallbackStreamingProvider) Chat(ctx context.Context, messages []providers.Message, tools []providers.ToolDefinition, model string, options map[string]interface{}) (*providers.LLMResponse, error) { + return p.chat(ctx) +} + +func (p *fallbackStreamingProvider) ChatStream(ctx context.Context, messages []providers.Message, tools []providers.ToolDefinition, model string, options map[string]interface{}, onDelta func(string)) (*providers.LLMResponse, error) { + return p.stream(ctx, onDelta) +} + +func (p *fallbackStreamingProvider) GetDefaultModel() string { return "stream-model" } + +type asyncCompactionProvider struct { + mu sync.Mutex + started chan int + release chan struct{} + finished chan int + calls int +} + +func (p *asyncCompactionProvider) Chat(ctx context.Context, messages []providers.Message, tools []providers.ToolDefinition, model string, options map[string]interface{}) (*providers.LLMResponse, error) { + p.mu.Lock() + p.calls++ + call := p.calls + p.mu.Unlock() + if p.started != nil { + p.started <- call + } + if p.release != nil { + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-p.release: + } + } + if p.finished != nil { + p.finished <- call + } + return &providers.LLMResponse{Content: "Key Facts\n- compacted", FinishReason: "stop"}, nil +} + +func (p *asyncCompactionProvider) GetDefaultModel() string { return "async-model" } + +func TestCompactSessionTriggeredByTokenPressure(t *testing.T) { + t.Parallel() + + sm := session.NewSessionManager(t.TempDir()) + key := "cli:pressure" + for _, content := range []string{"one", "two", "three", "four", "five", "six"} { + sm.AddMessage(key, "user", content) + } + provider := &pressureProvider{tokens: 900} + loop := 
&AgentLoop{ + provider: provider, + model: provider.GetDefaultModel(), + maxTokens: 1000, + providerNames: []string{"pressure"}, + sessions: sm, + compactionEnabled: true, + compactionTrigger: 100, + compactionProtectLastN: 2, + compactionKeepRecent: 2, + compactionTargetRatio: 0.35, + compactionPressureThreshold: 0.8, + compactionMaxSummaryChars: 6000, + compactionMaxTranscriptChars: 20000, + } + + applied, _, _ := loop.compactSessionIfNeeded(context.Background(), key) + if !applied { + t.Fatal("expected compaction to apply") + } + + history := sm.GetPromptHistory(key) + if len(history) != 3 { + t.Fatalf("expected ratio-based keep count 3, got %d", len(history)) + } + if history[0].Content != "four" || history[2].Content != "six" { + t.Fatalf("expected tail messages preserved, got %#v", history) + } + if summary := sm.GetSummary(key); summary == "" { + t.Fatal("expected compaction summary to be written") + } +} + +func TestFinalizeUserMessageDoesNotWaitForCompaction(t *testing.T) { + t.Parallel() + + sm := session.NewSessionManager(t.TempDir()) + key := "cli:async" + for _, content := range []string{"one", "two", "three", "four", "five", "six"} { + sm.AddMessage(key, "user", content) + } + provider := &asyncCompactionProvider{ + started: make(chan int, 2), + release: make(chan struct{}), + } + loop := &AgentLoop{ + provider: provider, + model: provider.GetDefaultModel(), + sessions: sm, + compactionEnabled: true, + compactionTrigger: 4, + compactionProtectLastN: 2, + compactionKeepRecent: 2, + compactionTargetRatio: 0.35, + compactionPressureThreshold: 0.1, + compactionMaxSummaryChars: 6000, + compactionMaxTranscriptChars: 20000, + compactionRunner: lifecycle.NewLoopRunner(), + compactionSignal: make(chan struct{}, 1), + compactionQueued: map[string]struct{}{}, + compactionInflight: map[string]struct{}{}, + compactionDirty: map[string]struct{}{}, + } + t.Cleanup(loop.Stop) + + start := time.Now() + loop.finalizeUserMessage(key, "en", nil, "final") + if elapsed := 
time.Since(start); elapsed > 150*time.Millisecond { + t.Fatalf("expected finalizeUserMessage to return quickly, took %s", elapsed) + } + select { + case <-provider.started: + case <-time.After(500 * time.Millisecond): + t.Fatal("expected async compaction to start in background") + } + close(provider.release) + deadline := time.Now().Add(2 * time.Second) + for time.Now().Before(deadline) { + if summary := sm.GetSummary(key); summary != "" { + return + } + time.Sleep(20 * time.Millisecond) + } + t.Fatal("expected async compaction summary to be written") +} + +func TestCompactionWorkerRetriesDirtySessionWithoutLosingNewMessages(t *testing.T) { + t.Parallel() + + sm := session.NewSessionManager(t.TempDir()) + key := "cli:dirty" + for _, content := range []string{"one", "two", "three", "four", "five", "six"} { + sm.AddMessage(key, "user", content) + } + provider := &asyncCompactionProvider{ + started: make(chan int, 4), + release: make(chan struct{}, 4), + finished: make(chan int, 4), + } + loop := &AgentLoop{ + provider: provider, + model: provider.GetDefaultModel(), + sessions: sm, + compactionEnabled: true, + compactionTrigger: 4, + compactionProtectLastN: 2, + compactionKeepRecent: 2, + compactionTargetRatio: 0.35, + compactionPressureThreshold: 0.1, + compactionMaxSummaryChars: 6000, + compactionMaxTranscriptChars: 20000, + compactionRunner: lifecycle.NewLoopRunner(), + compactionSignal: make(chan struct{}, 1), + compactionQueued: map[string]struct{}{}, + compactionInflight: map[string]struct{}{}, + compactionDirty: map[string]struct{}{}, + } + t.Cleanup(loop.Stop) + + loop.enqueueSessionCompaction(key) + select { + case <-provider.started: + case <-time.After(time.Second): + t.Fatal("expected first compaction run to start") + } + + sm.AddMessage(key, "assistant", "seven") + loop.enqueueSessionCompaction(key) + provider.release <- struct{}{} + + select { + case <-provider.finished: + case <-time.After(time.Second): + t.Fatal("expected first compaction run to 
finish") + } + select { + case call := <-provider.started: + if call != 2 { + t.Fatalf("expected second compaction attempt after dirty retry, got call %d", call) + } + case <-time.After(time.Second): + t.Fatal("expected dirty session to trigger a second compaction run") + } + provider.release <- struct{}{} + select { + case <-provider.finished: + case <-time.After(time.Second): + t.Fatal("expected second compaction run to finish") + } + + deadline := time.Now().Add(2 * time.Second) + for time.Now().Before(deadline) { + history := sm.GetPromptHistory(key) + if len(history) > 0 && history[len(history)-1].Content == "seven" && sm.GetSummary(key) != "" { + return + } + time.Sleep(20 * time.Millisecond) + } + t.Fatal("expected retried compaction to preserve new message and summary") +} + +func TestRequestStreamingLLMResponseFallsBackBeforeFirstDelta(t *testing.T) { + t.Parallel() + + provider := &fallbackStreamingProvider{ + stream: func(ctx context.Context, onDelta func(string)) (*providers.LLMResponse, error) { + <-ctx.Done() + return nil, providers.NewProviderExecutionError("stream_stale", "stream stale", "stream", true, "test") + }, + chat: func(ctx context.Context) (*providers.LLMResponse, error) { + return &providers.LLMResponse{Content: "fallback", FinishReason: "stop"}, nil + }, + } + loop := &AgentLoop{ + bus: bus.NewMessageBus(), + sessionStreamed: map[string]bool{}, + } + ctx, cancel := context.WithTimeout(context.Background(), 400*time.Millisecond) + defer cancel() + + resp, attempts, err := loop.requestStreamingLLMResponse(llmTurnLoopConfig{ + ctx: ctx, + sessionKey: "cli:test", + toolChannel: "telegram", + toolChatID: "chat", + enableStreaming: true, + }, provider, provider, provider.GetDefaultModel(), []providers.Message{{Role: "user", Content: "hello"}}, nil, nil) + if err != nil { + t.Fatalf("expected fallback success, got %v", err) + } + if attempts != 2 { + t.Fatalf("expected streaming + fallback attempts, got %d", attempts) + } + if resp == nil || 
resp.Content != "fallback" { + t.Fatalf("unexpected fallback response: %#v", resp) + } +} + +func TestRequestStreamingLLMResponseDoesNotFallbackAfterDelta(t *testing.T) { + t.Parallel() + + provider := &fallbackStreamingProvider{ + stream: func(ctx context.Context, onDelta func(string)) (*providers.LLMResponse, error) { + onDelta("partial") + return nil, providers.NewProviderExecutionError("stream_failed", "stream failed", "stream", true, "test") + }, + chat: func(ctx context.Context) (*providers.LLMResponse, error) { + return &providers.LLMResponse{Content: "fallback", FinishReason: "stop"}, nil + }, + } + loop := &AgentLoop{ + bus: bus.NewMessageBus(), + sessionStreamed: map[string]bool{}, + } + + resp, attempts, err := loop.requestStreamingLLMResponse(llmTurnLoopConfig{ + ctx: context.Background(), + sessionKey: "cli:test", + toolChannel: "telegram", + toolChatID: "chat", + enableStreaming: true, + }, provider, provider, provider.GetDefaultModel(), []providers.Message{{Role: "user", Content: "hello"}}, nil, nil) + if err == nil { + t.Fatal("expected stream failure without fallback") + } + if attempts != 1 { + t.Fatalf("expected single streaming attempt, got %d", attempts) + } + if resp != nil { + t.Fatalf("expected nil response on post-delta stream failure, got %#v", resp) + } +} + +func TestRunLLMTurnLoopReturnsRetryLimitError(t *testing.T) { + t.Parallel() + + provider := &sequenceProvider{ + responses: []*providers.LLMResponse{{ + Content: "", + ToolCalls: []providers.ToolCall{ + {ID: "tool-1", Name: "system_info", Arguments: map[string]interface{}{}}, + }, + FinishReason: "tool_calls", + }}, + } + loop := &AgentLoop{ + provider: provider, + model: provider.GetDefaultModel(), + maxIterations: 1, + tools: toolspkg.NewToolRegistry(), + providerNames: []string{"sequence"}, + sessionProvider: map[string]string{}, + } + loop.tools.Register(toolspkg.NewSystemInfoTool()) + _, err := loop.runLLMTurnLoop(llmTurnLoopConfig{ + ctx: context.Background(), + sessionKey: 
"cli:test", + messages: []providers.Message{{Role: "user", Content: "hello"}}, + }) + if err == nil || !strings.Contains(err.Error(), "max tool iterations exceeded") { + t.Fatalf("expected retry limit error, got %v", err) + } +} diff --git a/pkg/api/server.go b/pkg/api/server.go index b17f9a9..6d51269 100644 --- a/pkg/api/server.go +++ b/pkg/api/server.go @@ -47,7 +47,8 @@ type Server struct { workspacePath string logFilePath string onChat func(ctx context.Context, sessionKey, content string) (string, error) - onChatHistory func(sessionKey string) []map[string]interface{} + onChatHistory func(query ChatHistoryQuery) []map[string]interface{} + onSessionSearch func(query SessionSearchQuery) []map[string]interface{} onConfigAfter func(forceRuntimeReload bool) error onCron func(action string, args map[string]interface{}) (interface{}, error) onToolsCatalog func() interface{} @@ -70,6 +71,22 @@ type channelDraftStore struct { weixinRuntime *channels.WeixinChannel } +type ChatHistoryQuery struct { + Session string + Around int + Before int + After int + Limit int +} + +type SessionSearchQuery struct { + Query string + Limit int + Kinds []string + ExcludeCurrent bool + Session string +} + func NewServer(host string, port int, token string) *Server { addr := strings.TrimSpace(host) if addr == "" { @@ -94,9 +111,12 @@ func (s *Server) SetToken(token string) { s.token = strings.TrimSpace(tok func (s *Server) SetChatHandler(fn func(ctx context.Context, sessionKey, content string) (string, error)) { s.onChat = fn } -func (s *Server) SetChatHistoryHandler(fn func(sessionKey string) []map[string]interface{}) { +func (s *Server) SetChatHistoryHandler(fn func(query ChatHistoryQuery) []map[string]interface{}) { s.onChatHistory = fn } +func (s *Server) SetSessionSearchHandler(fn func(query SessionSearchQuery) []map[string]interface{}) { + s.onSessionSearch = fn +} func (s *Server) SetMessageBus(mb *bus.MessageBus) { s.messageBus = mb } func (s *Server) SetConfigAfterHook(fn 
func(forceRuntimeReload bool) error) { s.onConfigAfter = fn } func (s *Server) SetCronHandler(fn func(action string, args map[string]interface{}) (interface{}, error)) { @@ -611,6 +631,7 @@ func (s *Server) Start(ctx context.Context) error { mux.HandleFunc("/api/cron", s.handleWebUICron) mux.HandleFunc("/api/skills", s.handleWebUISkills) mux.HandleFunc("/api/sessions", s.handleWebUISessions) + mux.HandleFunc("/api/sessions/search", s.handleWebUISessionSearch) mux.HandleFunc("/api/memory", s.handleWebUIMemory) mux.HandleFunc("/api/workspace_file", s.handleWebUIWorkspaceFile) mux.HandleFunc("/api/workspace_docs", s.handleWebUIWorkspaceDocs) @@ -1516,7 +1537,14 @@ func (s *Server) handleWebUIChatHistory(w http.ResponseWriter, r *http.Request) writeJSON(w, map[string]interface{}{"ok": true, "session": session, "messages": []interface{}{}}) return } - writeJSON(w, map[string]interface{}{"ok": true, "session": session, "messages": s.onChatHistory(session)}) + query := ChatHistoryQuery{ + Session: session, + Around: queryBoundedPositiveInt(r, "around", 0, 1_000_000), + Before: queryBoundedPositiveInt(r, "before", 0, 1_000_000), + After: queryBoundedPositiveInt(r, "after", 0, 1_000_000), + Limit: queryBoundedPositiveInt(r, "limit", 200, 2000), + } + writeJSON(w, map[string]interface{}{"ok": true, "session": session, "messages": s.onChatHistory(query)}) } func (s *Server) handleWebUIChatLive(w http.ResponseWriter, r *http.Request) { @@ -3349,10 +3377,25 @@ func (s *Server) handleWebUISessions(w http.ResponseWriter, r *http.Request) { continue } name := e.Name() - if !strings.HasSuffix(name, ".jsonl") || strings.Contains(name, ".deleted.") { + key := "" + switch { + case strings.HasSuffix(name, ".meta.json"): + key = strings.TrimSuffix(name, ".meta.json") + case strings.HasSuffix(name, ".active.jsonl"): + key = strings.TrimSuffix(name, ".active.jsonl") + case strings.HasSuffix(name, ".jsonl") && !strings.Contains(name, ".deleted."): + key = strings.TrimSuffix(name, ".jsonl") 
+ if strings.HasSuffix(key, ".active") { + key = strings.TrimSuffix(key, ".active") + } + if idx := strings.LastIndex(key, "."); idx > 0 { + if seqPart := key[idx+1:]; len(seqPart) == 4 && regexp.MustCompile(`^\d{4}$`).MatchString(seqPart) { + key = key[:idx] + } + } + default: continue } - key := strings.TrimSuffix(name, ".jsonl") if strings.TrimSpace(key) == "" { continue } @@ -3376,6 +3419,65 @@ func (s *Server) handleWebUISessions(w http.ResponseWriter, r *http.Request) { writeJSON(w, map[string]interface{}{"ok": true, "sessions": out}) } +func (s *Server) handleWebUISessionSearch(w http.ResponseWriter, r *http.Request) { + if !s.checkAuth(r) { + http.Error(w, "unauthorized", http.StatusUnauthorized) + return + } + if r.Method != http.MethodGet { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + if s.onSessionSearch == nil { + writeJSON(w, map[string]interface{}{"ok": true, "results": []interface{}{}}) + return + } + queryText := strings.TrimSpace(r.URL.Query().Get("query")) + if queryText == "" { + writeJSON(w, map[string]interface{}{"ok": true, "results": []interface{}{}}) + return + } + kinds := splitCSVQueryParam(r.URL.Query()["kinds"]) + if len(kinds) == 0 { + kinds = splitCSVQueryParam([]string{r.URL.Query().Get("kind")}) + } + excludeCurrent := false + if raw := strings.TrimSpace(r.URL.Query().Get("exclude_current")); raw != "" { + excludeCurrent = raw == "1" || strings.EqualFold(raw, "true") || strings.EqualFold(raw, "yes") + } + query := SessionSearchQuery{ + Query: queryText, + Limit: queryBoundedPositiveInt(r, "limit", 5, 100), + Kinds: kinds, + ExcludeCurrent: excludeCurrent, + Session: strings.TrimSpace(r.URL.Query().Get("session")), + } + writeJSON(w, map[string]interface{}{ + "ok": true, + "query": query.Query, + "results": s.onSessionSearch(query), + }) +} + +func splitCSVQueryParam(values []string) []string { + out := make([]string, 0, len(values)) + seen := map[string]struct{}{} + for _, value := range values { 
+ for _, item := range strings.Split(value, ",") { + item = strings.TrimSpace(item) + if item == "" { + continue + } + if _, ok := seen[item]; ok { + continue + } + seen[item] = struct{}{} + out = append(out, item) + } + } + return out +} + func isUserFacingSessionKey(key string) bool { k := strings.ToLower(strings.TrimSpace(key)) if k == "" { diff --git a/pkg/api/server_test.go b/pkg/api/server_test.go index 3ac1f9e..4b4f6b2 100644 --- a/pkg/api/server_test.go +++ b/pkg/api/server_test.go @@ -408,6 +408,73 @@ func TestHandleWebUISessionsHidesInternalSessionsByDefault(t *testing.T) { } } +func TestHandleWebUIChatHistorySupportsWindowQuery(t *testing.T) { + t.Parallel() + + srv := NewServer("127.0.0.1", 0, "") + var got ChatHistoryQuery + srv.SetChatHistoryHandler(func(query ChatHistoryQuery) []map[string]interface{} { + got = query + return []map[string]interface{}{{"role": "assistant", "content": "ok"}} + }) + + req := httptest.NewRequest(http.MethodGet, "/api/chat/history?session=alpha&after=2&limit=3", nil) + rec := httptest.NewRecorder() + srv.handleWebUIChatHistory(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("expected 200, got %d: %s", rec.Code, rec.Body.String()) + } + if got.Session != "alpha" || got.After != 2 || got.Limit != 3 { + t.Fatalf("unexpected query: %+v", got) + } +} + +func TestHandleWebUISessionSearchReturnsResults(t *testing.T) { + t.Parallel() + + srv := NewServer("127.0.0.1", 0, "") + var got SessionSearchQuery + srv.SetSessionSearchHandler(func(query SessionSearchQuery) []map[string]interface{} { + got = query + return []map[string]interface{}{ + { + "key": "main", + "kind": "main", + "updated_at": int64(123), + "summary": "deploy notes", + "score": 2, + "snippets": []map[string]interface{}{{"seq": 3, "content": "deploy timeout"}}, + }, + } + }) + + req := httptest.NewRequest(http.MethodGet, "/api/sessions/search?query=deploy&kinds=main,cron&exclude_current=1&session=current&limit=7", nil) + rec := httptest.NewRecorder() + 
srv.handleWebUISessionSearch(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("expected 200, got %d: %s", rec.Code, rec.Body.String()) + } + if got.Query != "deploy" || got.Session != "current" || !got.ExcludeCurrent || got.Limit != 7 { + t.Fatalf("unexpected search query: %+v", got) + } + if len(got.Kinds) != 2 || got.Kinds[0] != "main" || got.Kinds[1] != "cron" { + t.Fatalf("unexpected kinds: %+v", got.Kinds) + } + + var payload struct { + OK bool `json:"ok"` + Results []map[string]interface{} `json:"results"` + } + if err := json.Unmarshal(rec.Body.Bytes(), &payload); err != nil { + t.Fatalf("decode payload: %v", err) + } + if !payload.OK || len(payload.Results) != 1 { + t.Fatalf("unexpected payload: %+v", payload) + } +} + func TestSaveProviderConfigForcesRuntimeReload(t *testing.T) { t.Parallel() diff --git a/pkg/config/config.go b/pkg/config/config.go index cd729af..61e26f4 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -106,15 +106,16 @@ type SubagentToolsConfig struct { } type SubagentRuntimeConfig struct { - Provider string `json:"provider,omitempty"` - Model string `json:"model,omitempty"` - Temperature float64 `json:"temperature,omitempty"` - TimeoutSec int `json:"timeout_sec,omitempty"` - MaxRetries int `json:"max_retries,omitempty"` - RetryBackoffMs int `json:"retry_backoff_ms,omitempty"` - MaxTaskChars int `json:"max_task_chars,omitempty"` - MaxResultChars int `json:"max_result_chars,omitempty"` - MaxParallelRuns int `json:"max_parallel_runs,omitempty"` + Provider string `json:"provider,omitempty"` + Model string `json:"model,omitempty"` + Temperature float64 `json:"temperature,omitempty"` + TimeoutSec int `json:"timeout_sec,omitempty"` + MaxRetries int `json:"max_retries,omitempty"` + RetryBackoffMs int `json:"retry_backoff_ms,omitempty"` + MaxTaskChars int `json:"max_task_chars,omitempty"` + MaxResultChars int `json:"max_result_chars,omitempty"` + MaxParallelRuns int `json:"max_parallel_runs,omitempty"` + MaxToolIterations 
int `json:"max_tool_iterations,omitempty"` } type AgentDefaults struct { @@ -158,12 +159,15 @@ type SystemSummaryPolicyConfig struct { } type ContextCompactionConfig struct { - Enabled bool `json:"enabled" env:"CLAWGO_AGENTS_DEFAULTS_CONTEXT_COMPACTION_ENABLED"` - Mode string `json:"mode" env:"CLAWGO_AGENTS_DEFAULTS_CONTEXT_COMPACTION_MODE"` - TriggerMessages int `json:"trigger_messages" env:"CLAWGO_AGENTS_DEFAULTS_CONTEXT_COMPACTION_TRIGGER_MESSAGES"` - KeepRecentMessages int `json:"keep_recent_messages" env:"CLAWGO_AGENTS_DEFAULTS_CONTEXT_COMPACTION_KEEP_RECENT_MESSAGES"` - MaxSummaryChars int `json:"max_summary_chars" env:"CLAWGO_AGENTS_DEFAULTS_CONTEXT_COMPACTION_MAX_SUMMARY_CHARS"` - MaxTranscriptChars int `json:"max_transcript_chars" env:"CLAWGO_AGENTS_DEFAULTS_CONTEXT_COMPACTION_MAX_TRANSCRIPT_CHARS"` + Enabled bool `json:"enabled" env:"CLAWGO_AGENTS_DEFAULTS_CONTEXT_COMPACTION_ENABLED"` + Mode string `json:"mode" env:"CLAWGO_AGENTS_DEFAULTS_CONTEXT_COMPACTION_MODE"` + TriggerMessages int `json:"trigger_messages" env:"CLAWGO_AGENTS_DEFAULTS_CONTEXT_COMPACTION_TRIGGER_MESSAGES"` + KeepRecentMessages int `json:"keep_recent_messages" env:"CLAWGO_AGENTS_DEFAULTS_CONTEXT_COMPACTION_KEEP_RECENT_MESSAGES"` + TargetRatio float64 `json:"target_ratio" env:"CLAWGO_AGENTS_DEFAULTS_CONTEXT_COMPACTION_TARGET_RATIO"` + ProtectLastN int `json:"protect_last_n" env:"CLAWGO_AGENTS_DEFAULTS_CONTEXT_COMPACTION_PROTECT_LAST_N"` + PressureThreshold float64 `json:"pressure_threshold" env:"CLAWGO_AGENTS_DEFAULTS_CONTEXT_COMPACTION_PRESSURE_THRESHOLD"` + MaxSummaryChars int `json:"max_summary_chars" env:"CLAWGO_AGENTS_DEFAULTS_CONTEXT_COMPACTION_MAX_SUMMARY_CHARS"` + MaxTranscriptChars int `json:"max_transcript_chars" env:"CLAWGO_AGENTS_DEFAULTS_CONTEXT_COMPACTION_MAX_TRANSCRIPT_CHARS"` } type ChannelsConfig struct { @@ -402,6 +406,9 @@ func DefaultConfig() *Config { Mode: "summary", TriggerMessages: 60, KeepRecentMessages: 20, + TargetRatio: 0.35, + ProtectLastN: 12, + 
PressureThreshold: 0.8, MaxSummaryChars: 6000, MaxTranscriptChars: 20000, }, diff --git a/pkg/config/validate.go b/pkg/config/validate.go index 1963e8d..ac2886f 100644 --- a/pkg/config/validate.go +++ b/pkg/config/validate.go @@ -76,6 +76,15 @@ func Validate(cfg *Config) []error { if cc.KeepRecentMessages <= 0 { errs = append(errs, fmt.Errorf("agents.defaults.context_compaction.keep_recent_messages must be > 0 when enabled=true")) } + if cc.TargetRatio <= 0 || cc.TargetRatio >= 1 { + errs = append(errs, fmt.Errorf("agents.defaults.context_compaction.target_ratio must be > 0 and < 1 when enabled=true")) + } + if cc.ProtectLastN <= 0 { + errs = append(errs, fmt.Errorf("agents.defaults.context_compaction.protect_last_n must be > 0 when enabled=true")) + } + if cc.PressureThreshold <= 0 || cc.PressureThreshold > 1 { + errs = append(errs, fmt.Errorf("agents.defaults.context_compaction.pressure_threshold must be > 0 and <= 1 when enabled=true")) + } if cc.TriggerMessages > 0 && cc.KeepRecentMessages >= cc.TriggerMessages { errs = append(errs, fmt.Errorf("agents.defaults.context_compaction.keep_recent_messages must be < trigger_messages")) } @@ -395,6 +404,9 @@ func validateSubagents(cfg *Config) []error { if raw.Runtime.MaxParallelRuns < 0 { errs = append(errs, fmt.Errorf("agents.subagents.%s.runtime.max_parallel_runs must be >= 0", id)) } + if raw.Runtime.MaxToolIterations < 0 { + errs = append(errs, fmt.Errorf("agents.subagents.%s.runtime.max_tool_iterations must be >= 0", id)) + } if raw.Tools.MaxParallelCalls < 0 { errs = append(errs, fmt.Errorf("agents.subagents.%s.tools.max_parallel_calls must be >= 0", id)) } diff --git a/pkg/jsonlog/store.go b/pkg/jsonlog/store.go new file mode 100644 index 0000000..7ac92ca --- /dev/null +++ b/pkg/jsonlog/store.go @@ -0,0 +1,96 @@ +package jsonlog + +import ( + "bufio" + "encoding/json" + "fmt" + "os" + "path/filepath" +) + +// AppendLine appends a single JSON-encoded line and returns the resulting file size. 
+func AppendLine(path string, value interface{}) (int64, error) { + data, err := json.Marshal(value) + if err != nil { + return 0, err + } + return AppendRawLine(path, data) +} + +// AppendRawLine appends a raw JSON line and returns the resulting file size. +func AppendRawLine(path string, line []byte) (int64, error) { + if err := os.MkdirAll(filepath.Dir(path), 0755); err != nil { + return 0, err + } + f, err := os.OpenFile(path, os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0644) + if err != nil { + return 0, err + } + defer f.Close() + if _, err := f.Write(append(line, '\n')); err != nil { + return 0, err + } + st, err := f.Stat() + if err != nil { + return 0, err + } + return st.Size(), nil +} + +// Scan walks each non-empty JSONL line in order. +func Scan(path string, fn func(line []byte) error) error { + f, err := os.Open(path) + if err != nil { + return err + } + defer f.Close() + + scanner := bufio.NewScanner(f) + scanner.Buffer(make([]byte, 0, 64*1024), 8*1024*1024) + for scanner.Scan() { + line := append([]byte(nil), scanner.Bytes()...) + if len(line) == 0 { + continue + } + if err := fn(line); err != nil { + return err + } + } + if err := scanner.Err(); err != nil { + return fmt.Errorf("scan %s: %w", path, err) + } + return nil +} + +// FileSize returns 0 when the file does not exist. +func FileSize(path string) (int64, error) { + st, err := os.Stat(path) + if err != nil { + if os.IsNotExist(err) { + return 0, nil + } + return 0, err + } + return st.Size(), nil +} + +// ReadJSON reads a JSON sidecar file into dst. +func ReadJSON(path string, dst interface{}) error { + data, err := os.ReadFile(path) + if err != nil { + return err + } + return json.Unmarshal(data, dst) +} + +// WriteJSON writes a JSON sidecar file atomically enough for local runtime use. 
+func WriteJSON(path string, value interface{}) error { + if err := os.MkdirAll(filepath.Dir(path), 0755); err != nil { + return err + } + data, err := json.MarshalIndent(value, "", " ") + if err != nil { + return err + } + return os.WriteFile(path, data, 0644) +} diff --git a/pkg/providers/execution.go b/pkg/providers/execution.go index 92ae768..b290c07 100644 --- a/pkg/providers/execution.go +++ b/pkg/providers/execution.go @@ -4,8 +4,10 @@ import ( "bytes" "context" "encoding/json" + "errors" "fmt" "net/http" + "time" ) func newProviderExecutionError(code, message, stage string, retryable bool, source string) *ProviderExecutionError { @@ -18,6 +20,18 @@ func newProviderExecutionError(code, message, stage string, retryable bool, sour } } +func NewProviderExecutionError(code, message, stage string, retryable bool, source string) *ProviderExecutionError { + return newProviderExecutionError(code, message, stage, retryable, source) +} + +func ExecutionErrorCode(err error) string { + var execErr *ProviderExecutionError + if errors.As(err, &execErr) && execErr != nil { + return execErr.Code + } + return "" +} + func (p *HTTPProvider) executeJSONAttempts(ctx context.Context, endpoint string, payload interface{}, mutate func(*http.Request, authAttempt), classify func(int, []byte) (oauthFailureReason, bool)) (ProviderExecutionResult, error) { jsonData, err := json.Marshal(payload) if err != nil { @@ -42,12 +56,13 @@ func (p *HTTPProvider) executeJSONAttempts(ctx context.Context, endpoint string, } body, status, ctype, err := p.doJSONAttempt(req, attempt) if err != nil { + execErr := newProviderExecutionError("request_failed", err.Error(), "request", false, p.providerName) return ProviderExecutionResult{ StatusCode: status, ContentType: ctype, AttemptKind: attempt.kind, - Error: newProviderExecutionError("request_failed", err.Error(), "request", false, p.providerName), - }, err + Error: execErr, + }, execErr } reason, retry := classify(status, body) last = 
ProviderExecutionResult{ @@ -78,8 +93,10 @@ func (p *HTTPProvider) executeStreamAttempts(ctx context.Context, endpoint strin } var last ProviderExecutionResult for _, attempt := range attempts { - req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(jsonData)) + attemptCtx, cancel := context.WithCancel(ctx) + req, err := http.NewRequestWithContext(attemptCtx, http.MethodPost, endpoint, bytes.NewReader(jsonData)) if err != nil { + cancel() return ProviderExecutionResult{Error: newProviderExecutionError("request_build_failed", err.Error(), "request", false, p.providerName)}, fmt.Errorf("failed to create request: %w", err) } req.Header.Set("Content-Type", "application/json") @@ -89,14 +106,21 @@ func (p *HTTPProvider) executeStreamAttempts(ctx context.Context, endpoint strin if mutate != nil { mutate(req, attempt) } - body, status, ctype, quotaHit, err := p.doStreamAttempt(req, attempt, onEvent) + streamOptions := streamAttemptTimeouts(ctx) + body, status, ctype, quotaHit, err := p.doStreamAttempt(req, attempt, onEvent, streamOptions, cancel) + cancel() if err != nil { + code := "stream_failed" + if streamOptions.staleTriggered { + code = "stream_stale" + } + execErr := newProviderExecutionError(code, err.Error(), "request", true, p.providerName) return ProviderExecutionResult{ StatusCode: status, ContentType: ctype, AttemptKind: attempt.kind, - Error: newProviderExecutionError("stream_failed", err.Error(), "request", false, p.providerName), - }, err + Error: execErr, + }, execErr } reason, _ := classifyOAuthFailure(status, body) last = ProviderExecutionResult{ @@ -115,3 +139,47 @@ func (p *HTTPProvider) executeStreamAttempts(ctx context.Context, endpoint strin } return last, nil } + +type streamAttemptOptions struct { + firstDeltaTimeout time.Duration + idleDeltaTimeout time.Duration + staleTriggered bool +} + +func streamAttemptTimeouts(ctx context.Context) *streamAttemptOptions { + opts := &streamAttemptOptions{ + 
firstDeltaTimeout: 20 * time.Second, + idleDeltaTimeout: 45 * time.Second, + } + if deadline, ok := ctx.Deadline(); ok { + remaining := time.Until(deadline) + if remaining > 0 { + first := remaining / 4 + idle := remaining / 3 + if remaining < 5*time.Second { + first = maxDuration(100*time.Millisecond, first) + } else if first < 5*time.Second { + first = 5 * time.Second + } + if remaining < 15*time.Second { + idle = maxDuration(250*time.Millisecond, idle) + } else if idle < 15*time.Second { + idle = 15 * time.Second + } + if first < opts.firstDeltaTimeout { + opts.firstDeltaTimeout = first + } + if idle < opts.idleDeltaTimeout { + opts.idleDeltaTimeout = idle + } + } + } + return opts +} + +func maxDuration(a, b time.Duration) time.Duration { + if a > b { + return a + } + return b +} diff --git a/pkg/providers/http_provider.go b/pkg/providers/http_provider.go index 7da55ff..94c2d92 100644 --- a/pkg/providers/http_provider.go +++ b/pkg/providers/http_provider.go @@ -983,7 +983,7 @@ func (p *HTTPProvider) doJSONAttempt(req *http.Request, attempt authAttempt) ([] return body, resp.StatusCode, strings.TrimSpace(resp.Header.Get("Content-Type")), nil } -func (p *HTTPProvider) doStreamAttempt(req *http.Request, attempt authAttempt, onEvent func(string)) ([]byte, int, string, bool, error) { +func (p *HTTPProvider) doStreamAttempt(req *http.Request, attempt authAttempt, onEvent func(string), options *streamAttemptOptions, cancel context.CancelFunc) ([]byte, int, string, bool, error) { client, err := p.httpClientForAttempt(attempt) if err != nil { return nil, 0, "", false, err @@ -1006,7 +1006,39 @@ func (p *HTTPProvider) doStreamAttempt(req *http.Request, attempt authAttempt, o scanner.Buffer(make([]byte, 0, 64*1024), 2*1024*1024) var dataLines []string var finalJSON []byte + lastActivity := time.Now() + firstDeltaSeen := false + done := make(chan struct{}) + if options != nil && cancel != nil { + go func() { + ticker := time.NewTicker(250 * time.Millisecond) + defer 
ticker.Stop() + defer close(done) + for { + select { + case <-req.Context().Done(): + return + case <-ticker.C: + timeout := options.firstDeltaTimeout + if firstDeltaSeen { + timeout = options.idleDeltaTimeout + } + if timeout > 0 && time.Since(lastActivity) > timeout { + options.staleTriggered = true + cancel() + return + } + } + } + }() + } else { + close(done) + } + defer func() { + <-done + }() for scanner.Scan() { + lastActivity = time.Now() line := scanner.Text() if strings.TrimSpace(line) == "" { if len(dataLines) > 0 { @@ -1015,6 +1047,7 @@ func (p *HTTPProvider) doStreamAttempt(req *http.Request, attempt authAttempt, o if strings.TrimSpace(payload) == "[DONE]" { continue } + firstDeltaSeen = true if onEvent != nil { onEvent(payload) } @@ -1039,6 +1072,9 @@ func (p *HTTPProvider) doStreamAttempt(req *http.Request, attempt authAttempt, o } } if err := scanner.Err(); err != nil { + if options != nil && options.staleTriggered { + return nil, resp.StatusCode, ctype, false, fmt.Errorf("stream stale: %w", err) + } return nil, resp.StatusCode, ctype, false, fmt.Errorf("failed to read stream: %w", err) } if len(finalJSON) == 0 { diff --git a/pkg/providers/iflow_provider.go b/pkg/providers/iflow_provider.go index fd9de26..f19070e 100644 --- a/pkg/providers/iflow_provider.go +++ b/pkg/providers/iflow_provider.go @@ -254,7 +254,7 @@ func doIFlowStreamWithAttempts(ctx context.Context, base *HTTPProvider, payload onDelta(txt) } } - }) + }, nil, nil) if err != nil { return nil, 0, "", err } diff --git a/pkg/providers/oauth_test.go b/pkg/providers/oauth_test.go index 2856674..9ffc332 100644 --- a/pkg/providers/oauth_test.go +++ b/pkg/providers/oauth_test.go @@ -216,7 +216,7 @@ func TestHTTPProviderOpenAICompatStreamMergesLateUsage(t *testing.T) { if err != nil { t.Fatalf("new request failed: %v", err) } - body, status, _, _, err := provider.doStreamAttempt(req, authAttempt{kind: "api_key", token: "token"}, nil) + body, status, _, _, err := provider.doStreamAttempt(req, 
authAttempt{kind: "api_key", token: "token"}, nil, nil, nil) if err != nil { t.Fatalf("stream attempt failed: %v", err) } diff --git a/pkg/providers/openai_compat_provider.go b/pkg/providers/openai_compat_provider.go index d409c4d..f4a2e71 100644 --- a/pkg/providers/openai_compat_provider.go +++ b/pkg/providers/openai_compat_provider.go @@ -143,7 +143,7 @@ func doOpenAICompatStreamWithAttempts(ctx context.Context, base *HTTPProvider, p onDelta(txt) } } - }) + }, nil, nil) if err != nil { return nil, 0, "", err } diff --git a/pkg/providers/types.go b/pkg/providers/types.go index d07a7f7..3846538 100644 --- a/pkg/providers/types.go +++ b/pkg/providers/types.go @@ -101,6 +101,19 @@ type ProviderExecutionError struct { Source string `json:"source,omitempty"` } +func (e *ProviderExecutionError) Error() string { + if e == nil { + return "" + } + if e.Message != "" { + return e.Message + } + if e.Code != "" { + return e.Code + } + return "provider execution error" +} + type ProviderExecutionResult struct { Body []byte `json:"-"` StatusCode int `json:"status_code,omitempty"` diff --git a/pkg/session/manager.go b/pkg/session/manager.go index 430b0a6..500827b 100644 --- a/pkg/session/manager.go +++ b/pkg/session/manager.go @@ -1,21 +1,34 @@ package session import ( - "bufio" "crypto/sha1" "encoding/hex" "encoding/json" "fmt" "os" "path/filepath" + "regexp" + "sort" "strconv" "strings" "sync" "time" + "unicode" + "github.com/YspCoder/clawgo/pkg/jsonlog" "github.com/YspCoder/clawgo/pkg/providers" ) +const ( + defaultSessionSegmentMaxMessages = 200 + defaultSessionSegmentMaxBytes = 2 * 1024 * 1024 + defaultSessionPromptLoadSegments = 2 + maxTokenRefsPerSessionToken = 48 + maxSearchSnippetsPerSession = 3 +) + +var archiveSegmentRe = regexp.MustCompile(`^(?P.+)\.(?P\d{4})\.jsonl$`) + type Session struct { Key string `json:"key"` SessionID string `json:"session_id,omitempty"` @@ -28,12 +41,18 @@ type Session struct { Created time.Time `json:"created"` Updated time.Time 
`json:"updated"` mu sync.RWMutex + segments []sessionSegmentMeta + nextSeq int + index *sessionIndexFile } type SessionManager struct { - sessions map[string]*Session - mu sync.RWMutex - storage string + sessions map[string]*Session + mu sync.RWMutex + storage string + segmentMaxMessages int + segmentMaxBytes int64 + promptLoadSegments int } type openClawEvent struct { @@ -51,14 +70,100 @@ type openClawEvent struct { } `json:"message,omitempty"` } +type sessionSegmentMeta struct { + Name string `json:"name"` + Archived bool `json:"archived,omitempty"` + FirstSeq int `json:"first_seq,omitempty"` + LastSeq int `json:"last_seq,omitempty"` + MessageCount int `json:"message_count,omitempty"` + LastOffset int64 `json:"last_offset,omitempty"` + UpdatedAt int64 `json:"updated_at,omitempty"` +} + +type sessionMetaFile struct { + Version int `json:"version"` + SessionKey string `json:"session_key"` + SessionID string `json:"session_id,omitempty"` + Kind string `json:"kind,omitempty"` + Summary string `json:"summary,omitempty"` + CompactionCount int `json:"compaction_count,omitempty"` + LastLanguage string `json:"last_language,omitempty"` + PreferredLanguage string `json:"preferred_language,omitempty"` + MessageCount int `json:"message_count,omitempty"` + CreatedAt int64 `json:"created_at,omitempty"` + UpdatedAt int64 `json:"updated_at,omitempty"` + NextSeq int `json:"next_seq,omitempty"` + Segments []sessionSegmentMeta `json:"segments,omitempty"` +} + +type sessionIndexRef struct { + Seq int `json:"seq"` + Role string `json:"role,omitempty"` + Segment string `json:"segment,omitempty"` + Snippet string `json:"snippet,omitempty"` +} + +type sessionIndexFile struct { + Version int `json:"version"` + SessionKey string `json:"session_key"` + LastSeq int `json:"last_seq,omitempty"` + LastOffset int64 `json:"last_offset,omitempty"` + Segment string `json:"segment,omitempty"` + UpdatedAt int64 `json:"updated_at,omitempty"` + Tokens map[string][]sessionIndexRef 
`json:"tokens,omitempty"` +} + +type SessionSearchSnippet struct { + Seq int `json:"seq"` + Role string `json:"role,omitempty"` + Segment string `json:"segment,omitempty"` + Content string `json:"content,omitempty"` +} + +type SessionSearchResult struct { + Key string `json:"key"` + Kind string `json:"kind,omitempty"` + Summary string `json:"summary,omitempty"` + UpdatedAt time.Time `json:"updated_at"` + Score int `json:"score"` + Snippets []SessionSearchSnippet `json:"snippets,omitempty"` +} + +type SessionCompactionSnapshot struct { + Key string + History []providers.Message + Summary string + NextSeq int +} + +type sessionsIndexEntry struct { + SessionID string `json:"sessionId"` + SessionKey string `json:"sessionKey"` + UpdatedAt int64 `json:"updatedAt"` + Kind string `json:"kind"` + ChatType string `json:"chatType"` + CompactionCount int `json:"compactionCount"` + SessionFile string `json:"sessionFile,omitempty"` + Summary string `json:"summary,omitempty"` + LastLanguage string `json:"lastLanguage,omitempty"` + PreferredLanguage string `json:"preferredLanguage,omitempty"` +} + +type appendMessageResult struct { + refreshSessionsIndex bool +} + func NewSessionManager(storage string) *SessionManager { sm := &SessionManager{ - sessions: make(map[string]*Session), - storage: storage, + sessions: make(map[string]*Session), + storage: storage, + segmentMaxMessages: readPositiveIntEnv("CLAWGO_SESSION_SEGMENT_MAX_MESSAGES", defaultSessionSegmentMaxMessages), + segmentMaxBytes: int64(readPositiveIntEnv("CLAWGO_SESSION_SEGMENT_MAX_BYTES", defaultSessionSegmentMaxBytes)), + promptLoadSegments: readPositiveIntEnv("CLAWGO_SESSION_PROMPT_LOAD_SEGMENTS", defaultSessionPromptLoadSegments), } if storage != "" { - os.MkdirAll(storage, 0755) + _ = os.MkdirAll(storage, 0755) sm.cleanupArchivedSessions() sm.loadSessions() } @@ -67,133 +172,193 @@ func NewSessionManager(storage string) *SessionManager { } func (sm *SessionManager) GetOrCreate(key string) *Session { + session, _ := 
sm.getOrCreate(key) + return session +} + +func (sm *SessionManager) getOrCreate(key string) (*Session, bool) { sm.mu.RLock() session, ok := sm.sessions[key] sm.mu.RUnlock() - if ok { - return session + return session, false } sm.mu.Lock() defer sm.mu.Unlock() - - // Re-check existence after acquiring Write lock if session, ok = sm.sessions[key]; ok { - return session + return session, false } - + now := time.Now() session = &Session{ Key: key, SessionID: deriveSessionID(key), Kind: detectSessionKind(key), Messages: []providers.Message{}, - Created: time.Now(), - Updated: time.Now(), + Created: now, + Updated: now, + nextSeq: 1, } sm.sessions[key] = session - - return session + return session, true } func (sm *SessionManager) AddMessage(sessionKey, role, content string) { - sm.AddMessageFull(sessionKey, providers.Message{ - Role: role, - Content: content, - }) + sm.AddMessageFull(sessionKey, providers.Message{Role: role, Content: content}) } func (sm *SessionManager) AddMessageFull(sessionKey string, msg providers.Message) { - session := sm.GetOrCreate(sessionKey) + session, created := sm.getOrCreate(sessionKey) + persisted := false + refreshIndex := created session.mu.Lock() session.Messages = append(session.Messages, msg) session.Updated = time.Now() + if session.Created.IsZero() { + session.Created = session.Updated + } + appendResult, err := sm.appendMessageLocked(session, msg) + persisted = err == nil + refreshIndex = refreshIndex || appendResult.refreshSessionsIndex session.mu.Unlock() - - // Persist immediately (append-only). 
- sm.appendMessage(sessionKey, msg) + if persisted && refreshIndex { + _ = sm.writeOpenClawSessionsIndex() + } } -func (sm *SessionManager) appendMessage(sessionKey string, msg providers.Message) error { - if sm.storage == "" { +func (sm *SessionManager) GetPromptHistory(key string) []providers.Message { + sm.mu.RLock() + session, ok := sm.sessions[key] + sm.mu.RUnlock() + if !ok { return nil } - - sessionPath := filepath.Join(sm.storage, sessionKey+".jsonl") - f, err := os.OpenFile(sessionPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) - if err != nil { - return err - } - defer f.Close() - - event := toOpenClawMessageEvent(msg) - data, err := json.Marshal(event) - if err != nil { - return err - } - - if _, err = f.Write(append(data, '\n')); err != nil { - return err - } - return sm.writeOpenClawSessionsIndex() + session.mu.RLock() + defer session.mu.RUnlock() + history := make([]providers.Message, len(session.Messages)) + copy(history, session.Messages) + return history } func (sm *SessionManager) GetHistory(key string) []providers.Message { sm.mu.RLock() session, ok := sm.sessions[key] sm.mu.RUnlock() - if !ok { return []providers.Message{} } session.mu.RLock() - defer session.mu.RUnlock() + segments := append([]sessionSegmentMeta(nil), session.segments...) 
+ loaded := make([]providers.Message, len(session.Messages)) + copy(loaded, session.Messages) + session.mu.RUnlock() - history := make([]providers.Message, len(session.Messages)) - copy(history, session.Messages) - return history + if len(segments) == 0 { + return loaded + } + all, err := sm.loadMessagesForSegments(segments) + if err != nil || len(all) == 0 { + return loaded + } + return all } func (sm *SessionManager) GetSummary(key string) string { sm.mu.RLock() session, ok := sm.sessions[key] sm.mu.RUnlock() - if !ok { return "" } - session.mu.RLock() defer session.mu.RUnlock() - return session.Summary } -func (sm *SessionManager) CompactSession(key string, keepLast int, note string) bool { - sm.mu.RLock() - session, ok := sm.sessions[key] - sm.mu.RUnlock() - if !ok { - return false - } - +func (sm *SessionManager) SetSummary(key, summary string) { + session := sm.GetOrCreate(key) session.mu.Lock() - defer session.mu.Unlock() - if keepLast <= 0 || len(session.Messages) <= keepLast { + trimmed := strings.TrimSpace(summary) + if session.Summary == trimmed { + session.mu.Unlock() + return + } + session.Summary = trimmed + session.Updated = time.Now() + _ = sm.persistSidecarsLocked(session) + session.mu.Unlock() + _ = sm.writeOpenClawSessionsIndex() +} + +func (sm *SessionManager) ApplyCompaction(key string, keep []providers.Message, summary string) bool { + session := sm.GetOrCreate(key) + session.mu.Lock() + applied := sm.applyCompactionLocked(session, keep, summary) + session.mu.Unlock() + if applied { + _ = sm.writeOpenClawSessionsIndex() + } + return applied +} + +func (sm *SessionManager) ApplyCompactionIfUnchanged(key string, baseNextSeq int, baseSummary string, keep []providers.Message, summary string) bool { + session := sm.GetOrCreate(key) + session.mu.Lock() + if session.nextSeq != baseNextSeq || session.Summary != baseSummary { + session.mu.Unlock() return false } - session.Messages = session.Messages[len(session.Messages)-keepLast:] + applied := 
sm.applyCompactionLocked(session, keep, summary) + session.mu.Unlock() + if applied { + _ = sm.writeOpenClawSessionsIndex() + } + return applied +} + +func (sm *SessionManager) applyCompactionLocked(session *Session, keep []providers.Message, summary string) bool { + if session == nil { + return false + } + if len(session.Messages) == 0 && len(keep) == 0 { + return false + } + session.Messages = append([]providers.Message(nil), keep...) + if trimmed := strings.TrimSpace(summary); trimmed != "" { + session.Summary = trimmed + } session.CompactionCount++ - if strings.TrimSpace(note) != "" { + session.Updated = time.Now() + if err := sm.persistSidecarsLocked(session); err != nil { + return false + } + return true +} + +func (sm *SessionManager) CompactSession(key string, keepLast int, note string) bool { + session := sm.GetOrCreate(key) + session.mu.Lock() + if keepLast <= 0 || len(session.Messages) <= keepLast { + session.mu.Unlock() + return false + } + session.Messages = append([]providers.Message(nil), session.Messages[len(session.Messages)-keepLast:]...) 
+ session.CompactionCount++ + if trimmed := strings.TrimSpace(note); trimmed != "" { if strings.TrimSpace(session.Summary) == "" { - session.Summary = note + session.Summary = trimmed } else { - session.Summary += "\n" + note + session.Summary += "\n" + trimmed } } session.Updated = time.Now() + err := sm.persistSidecarsLocked(session) + session.mu.Unlock() + if err == nil { + _ = sm.writeOpenClawSessionsIndex() + } return true } @@ -204,21 +369,45 @@ func (sm *SessionManager) GetLanguagePreferences(key string) (preferred string, if !ok { return "", "" } - session.mu.RLock() defer session.mu.RUnlock() return session.PreferredLanguage, session.LastLanguage } +func (sm *SessionManager) CompactionSnapshot(key string) SessionCompactionSnapshot { + sm.mu.RLock() + session, ok := sm.sessions[key] + sm.mu.RUnlock() + if !ok || session == nil { + return SessionCompactionSnapshot{Key: key} + } + session.mu.RLock() + defer session.mu.RUnlock() + history := make([]providers.Message, len(session.Messages)) + copy(history, session.Messages) + return SessionCompactionSnapshot{ + Key: key, + History: history, + Summary: session.Summary, + NextSeq: session.nextSeq, + } +} + func (sm *SessionManager) SetLastLanguage(key, lang string) { if strings.TrimSpace(lang) == "" { return } session := sm.GetOrCreate(key) session.mu.Lock() + if session.LastLanguage == lang { + session.mu.Unlock() + return + } session.LastLanguage = lang session.Updated = time.Now() + _ = sm.persistSidecarsLocked(session) session.mu.Unlock() + _ = sm.writeOpenClawSessionsIndex() } func (sm *SessionManager) SetPreferredLanguage(key, lang string) { @@ -227,17 +416,27 @@ func (sm *SessionManager) SetPreferredLanguage(key, lang string) { } session := sm.GetOrCreate(key) session.mu.Lock() + if session.PreferredLanguage == lang { + session.mu.Unlock() + return + } session.PreferredLanguage = lang session.Updated = time.Now() + _ = sm.persistSidecarsLocked(session) session.mu.Unlock() + _ = 
sm.writeOpenClawSessionsIndex() } func (sm *SessionManager) Save(session *Session) error { - // Messages are persisted incrementally via AddMessageFull. - // Metadata is now centralized in sessions.json (OpenClaw-style index). - if sm.storage == "" { + if sm.storage == "" || session == nil { return nil } + session.mu.Lock() + if err := sm.persistSidecarsLocked(session); err != nil { + session.mu.Unlock() + return err + } + session.mu.Unlock() return sm.writeOpenClawSessionsIndex() } @@ -254,6 +453,7 @@ func (sm *SessionManager) Keys() []string { for k := range sm.sessions { keys = append(keys, k) } + sort.Strings(keys) return keys } @@ -276,12 +476,92 @@ func (sm *SessionManager) List(limit int) []Session { }) s.mu.RUnlock() } + sort.Slice(items, func(i, j int) bool { return items[i].Updated.After(items[j].Updated) }) if limit > 0 && len(items) > limit { - return items[:limit] + items = items[:limit] } return items } +func (sm *SessionManager) Search(query string, kinds []string, excludeKey string, limit int) []SessionSearchResult { + terms := tokenizeQueryText(query) + if len(terms) == 0 { + return nil + } + if limit <= 0 { + limit = 5 + } + kindSet := make(map[string]struct{}, len(kinds)) + for _, item := range kinds { + if v := strings.ToLower(strings.TrimSpace(item)); v != "" { + kindSet[v] = struct{}{} + } + } + + sm.mu.RLock() + keys := make([]string, 0, len(sm.sessions)) + for key := range sm.sessions { + keys = append(keys, key) + } + sm.mu.RUnlock() + sort.Strings(keys) + + results := make([]SessionSearchResult, 0, len(keys)) + for _, key := range keys { + if key == strings.TrimSpace(excludeKey) { + continue + } + sm.mu.RLock() + session := sm.sessions[key] + sm.mu.RUnlock() + if session == nil { + continue + } + session.mu.RLock() + kind := session.Kind + summary := session.Summary + updated := session.Updated + index := cloneSessionIndex(session.index) + session.mu.RUnlock() + if len(kindSet) > 0 { + if _, ok := 
kindSet[strings.ToLower(strings.TrimSpace(kind))]; !ok { + continue + } + } + + if index != nil { + if result, ok := searchSessionIndex(key, kind, summary, updated, terms, index); ok { + results = append(results, result) + continue + } + } + indexPath := sm.sessionIndexPath(key) + if index, err := sm.readIndexFile(indexPath); err == nil { + if result, ok := searchSessionIndex(key, kind, summary, updated, terms, index); ok { + results = append(results, result) + continue + } + } + if result, ok := sm.searchSessionByScan(key, kind, summary, updated, terms); ok { + results = append(results, result) + } + } + + sort.Slice(results, func(i, j int) bool { + if results[i].Score != results[j].Score { + return results[i].Score > results[j].Score + } + if !results[i].UpdatedAt.Equal(results[j].UpdatedAt) { + return results[i].UpdatedAt.After(results[j].UpdatedAt) + } + return results[i].Key < results[j].Key + }) + if len(results) > limit { + results = results[:limit] + } + return results +} + func toOpenClawMessageEvent(msg providers.Message) openClawEvent { role := strings.TrimSpace(strings.ToLower(msg.Role)) mappedRole := role @@ -346,7 +626,6 @@ func fromJSONLLine(line []byte) (providers.Message, bool) { func deriveSessionID(key string) string { sum := sha1.Sum([]byte("clawgo-session:" + key)) h := hex.EncodeToString(sum[:]) - // UUID-like deterministic id return h[0:8] + "-" + h[8:12] + "-" + h[12:16] + "-" + h[16:20] + "-" + h[20:32] } @@ -372,24 +651,24 @@ func (sm *SessionManager) writeOpenClawSessionsIndex() error { } sm.mu.RLock() defer sm.mu.RUnlock() - index := map[string]map[string]interface{}{} + index := map[string]sessionsIndexEntry{} for key, s := range sm.sessions { s.mu.RLock() - sessionFile := filepath.Join(sm.storage, key+".jsonl") sid := strings.TrimSpace(s.SessionID) if sid == "" { sid = deriveSessionID(key) } - entry := map[string]interface{}{ - "sessionId": sid, - "sessionKey": key, - "updatedAt": s.Updated.UnixMilli(), - "systemSent": true, - 
"abortedLastRun": false, - "compactionCount": s.CompactionCount, - "chatType": mapKindToChatType(s.Kind), - "sessionFile": sessionFile, - "kind": s.Kind, + entry := sessionsIndexEntry{ + SessionID: sid, + SessionKey: key, + UpdatedAt: s.Updated.UnixMilli(), + ChatType: mapKindToChatType(s.Kind), + Kind: s.Kind, + CompactionCount: s.CompactionCount, + SessionFile: filepath.Join(sm.storage, activeSegmentFilename(key)), + Summary: s.Summary, + LastLanguage: s.LastLanguage, + PreferredLanguage: s.PreferredLanguage, } s.mu.RUnlock() index[key] = entry @@ -398,10 +677,7 @@ func (sm *SessionManager) writeOpenClawSessionsIndex() error { if err != nil { return err } - if err := os.WriteFile(filepath.Join(sm.storage, "sessions.json"), data, 0644); err != nil { - return err - } - return nil + return os.WriteFile(filepath.Join(sm.storage, "sessions.json"), data, 0644) } func (sm *SessionManager) cleanupArchivedSessions() { @@ -450,71 +726,1023 @@ func mapKindToChatType(kind string) string { } func (sm *SessionManager) loadSessions() error { - // 1) Load sessions index first (sessions.json as source of truth) - indexPath := filepath.Join(sm.storage, "sessions.json") - if data, err := os.ReadFile(indexPath); err == nil { - var index map[string]struct { - SessionID string `json:"sessionId"` - SessionKey string `json:"sessionKey"` - UpdatedAt int64 `json:"updatedAt"` - Kind string `json:"kind"` - ChatType string `json:"chatType"` - CompactionCount int `json:"compactionCount"` - } - if err := json.Unmarshal(data, &index); err == nil { - for key, row := range index { - session := sm.GetOrCreate(key) - session.mu.Lock() - if strings.TrimSpace(row.SessionID) != "" { - session.SessionID = row.SessionID - } - if strings.TrimSpace(row.Kind) != "" { - session.Kind = row.Kind - } else if strings.TrimSpace(row.ChatType) == "direct" { - session.Kind = "main" - } - if row.UpdatedAt > 0 { - session.Updated = time.UnixMilli(row.UpdatedAt) - } - session.CompactionCount = row.CompactionCount - 
session.mu.Unlock() - } - } - } - - // 2) Load JSONL histories - files, err := os.ReadDir(sm.storage) + fallbackIndex := sm.readSessionsIndex() + keys, err := sm.discoverSessionKeys() if err != nil { return err } - for _, file := range files { - if file.IsDir() || filepath.Ext(file.Name()) != ".jsonl" { + for _, key := range keys { + session, loadErr := sm.loadSessionFromDisk(key, fallbackIndex[key]) + if loadErr != nil { + return loadErr + } + if session == nil { continue } - sessionKey := strings.TrimSuffix(file.Name(), ".jsonl") - session := sm.GetOrCreate(sessionKey) - - f, err := os.Open(filepath.Join(sm.storage, file.Name())) - if err != nil { - continue - } - scanner := bufio.NewScanner(f) - scanner.Buffer(make([]byte, 0, 64*1024), 8*1024*1024) - session.mu.Lock() - for scanner.Scan() { - if msg, ok := fromJSONLLine(scanner.Bytes()); ok { - session.Messages = append(session.Messages, msg) - } - } - session.mu.Unlock() - closeErr := f.Close() - if err := scanner.Err(); err != nil { - return fmt.Errorf("scan session file %s: %w", file.Name(), err) - } - if closeErr != nil { - return fmt.Errorf("close session file %s: %w", file.Name(), closeErr) - } + sm.sessions[key] = session } - return sm.writeOpenClawSessionsIndex() } + +func (sm *SessionManager) readSessionsIndex() map[string]sessionsIndexEntry { + if sm.storage == "" { + return nil + } + data, err := os.ReadFile(filepath.Join(sm.storage, "sessions.json")) + if err != nil { + return nil + } + var index map[string]sessionsIndexEntry + if err := json.Unmarshal(data, &index); err != nil { + return nil + } + return index +} + +func (sm *SessionManager) loadSessionFromDisk(key string, fallback sessionsIndexEntry) (*Session, error) { + meta, err := sm.loadOrRebuildMeta(key) + if err != nil { + return nil, err + } + index, indexErr := sm.readIndexFile(sm.sessionIndexPath(key)) + if indexErr != nil { + index = nil + } + if meta == nil { + if strings.TrimSpace(fallback.SessionKey) == "" && key == "" { + return nil, 
nil + } + now := time.UnixMilli(fallback.UpdatedAt) + if fallback.UpdatedAt <= 0 { + now = time.Now() + } + return &Session{ + Key: key, + SessionID: firstNonEmpty(fallback.SessionID, deriveSessionID(key)), + Kind: firstNonEmpty(fallback.Kind, detectSessionKind(key)), + Summary: fallback.Summary, + CompactionCount: fallback.CompactionCount, + LastLanguage: fallback.LastLanguage, + PreferredLanguage: fallback.PreferredLanguage, + Created: now, + Updated: now, + nextSeq: 1, + index: index, + }, nil + } + + workingMessages, err := sm.loadWorkingSet(meta.Segments) + if err != nil { + return nil, err + } + created := time.UnixMilli(meta.CreatedAt) + if meta.CreatedAt <= 0 { + created = time.Now() + } + updated := time.UnixMilli(meta.UpdatedAt) + if meta.UpdatedAt <= 0 { + updated = created + } + if strings.TrimSpace(meta.Summary) == "" && strings.TrimSpace(fallback.Summary) != "" { + meta.Summary = fallback.Summary + } + if strings.TrimSpace(meta.LastLanguage) == "" && strings.TrimSpace(fallback.LastLanguage) != "" { + meta.LastLanguage = fallback.LastLanguage + } + if strings.TrimSpace(meta.PreferredLanguage) == "" && strings.TrimSpace(fallback.PreferredLanguage) != "" { + meta.PreferredLanguage = fallback.PreferredLanguage + } + return &Session{ + Key: key, + SessionID: firstNonEmpty(meta.SessionID, fallback.SessionID, deriveSessionID(key)), + Kind: firstNonEmpty(meta.Kind, fallback.Kind, detectSessionKind(key)), + Messages: workingMessages, + Summary: firstNonEmpty(meta.Summary, fallback.Summary), + CompactionCount: maxInt(meta.CompactionCount, fallback.CompactionCount), + LastLanguage: firstNonEmpty(meta.LastLanguage, fallback.LastLanguage), + PreferredLanguage: firstNonEmpty(meta.PreferredLanguage, fallback.PreferredLanguage), + Created: created, + Updated: updated, + segments: append([]sessionSegmentMeta(nil), meta.Segments...), + nextSeq: maxInt(meta.NextSeq, meta.MessageCount+1, 1), + index: index, + }, nil +} + +func (sm *SessionManager) loadOrRebuildMeta(key 
string) (*sessionMetaFile, error) { + metaPath := sm.sessionMetaPath(key) + indexPath := sm.sessionIndexPath(key) + + var meta sessionMetaFile + metaValid := false + if err := jsonlog.ReadJSON(metaPath, &meta); err == nil && meta.Version > 0 { + if sm.metaMatchesStorage(key, &meta) { + metaValid = true + } + } + if metaValid { + if _, err := sm.readIndexFile(indexPath); err == nil { + return &meta, nil + } + } + rebuilt, err := sm.rebuildSidecars(key, &meta) + if err != nil { + return nil, err + } + return rebuilt, nil +} + +func (sm *SessionManager) metaMatchesStorage(key string, meta *sessionMetaFile) bool { + if meta == nil || len(meta.Segments) == 0 { + return false + } + for _, segment := range meta.Segments { + size, err := jsonlog.FileSize(filepath.Join(sm.storage, segment.Name)) + if err != nil { + return false + } + if strings.EqualFold(segment.Name, activeSegmentFilename(key)) && size != segment.LastOffset { + return false + } + } + return true +} + +func (sm *SessionManager) rebuildSidecars(key string, seed *sessionMetaFile) (*sessionMetaFile, error) { + segments, err := sm.discoverSessionSegments(key) + if err != nil { + return nil, err + } + if len(segments) == 0 { + return nil, nil + } + + meta := &sessionMetaFile{ + Version: 1, + SessionKey: key, + SessionID: deriveSessionID(key), + Kind: detectSessionKind(key), + Summary: strings.TrimSpace(seedValue(seed, func(v *sessionMetaFile) string { return v.Summary })), + CompactionCount: seedInt(seed, func(v *sessionMetaFile) int { return v.CompactionCount }), + LastLanguage: seedValue(seed, func(v *sessionMetaFile) string { return v.LastLanguage }), + PreferredLanguage: seedValue(seed, func(v *sessionMetaFile) string { return v.PreferredLanguage }), + Segments: make([]sessionSegmentMeta, 0, len(segments)), + } + index := sessionIndexFile{ + Version: 1, + SessionKey: key, + LastSeq: 0, + LastOffset: 0, + Segment: "", + UpdatedAt: time.Now().UnixMilli(), + Tokens: map[string][]sessionIndexRef{}, + } + + now 
:= time.Now().UnixMilli() + seq := 0 + for _, name := range segments { + fullPath := filepath.Join(sm.storage, name) + size, err := jsonlog.FileSize(fullPath) + if err != nil { + return nil, err + } + seg := sessionSegmentMeta{ + Name: name, + Archived: !strings.EqualFold(name, activeSegmentFilename(key)), + LastOffset: size, + UpdatedAt: now, + } + if st, err := os.Stat(fullPath); err == nil { + seg.UpdatedAt = st.ModTime().UnixMilli() + if meta.CreatedAt == 0 || st.ModTime().UnixMilli() < meta.CreatedAt { + meta.CreatedAt = st.ModTime().UnixMilli() + } + if st.ModTime().UnixMilli() > meta.UpdatedAt { + meta.UpdatedAt = st.ModTime().UnixMilli() + } + } + if err := jsonlog.Scan(fullPath, func(line []byte) error { + msg, ok := fromJSONLLine(line) + if !ok { + return nil + } + seq++ + if seg.FirstSeq == 0 { + seg.FirstSeq = seq + } + seg.LastSeq = seq + seg.MessageCount++ + meta.MessageCount++ + appendTokens(index.Tokens, tokenizeIndexText(msg.Content), sessionIndexRef{ + Seq: seq, + Role: strings.ToLower(strings.TrimSpace(msg.Role)), + Segment: name, + Snippet: messageSnippet(msg.Content), + }) + return nil + }); err != nil { + return nil, err + } + meta.Segments = append(meta.Segments, seg) + } + if meta.CreatedAt == 0 { + meta.CreatedAt = now + } + if meta.UpdatedAt == 0 { + meta.UpdatedAt = now + } + meta.NextSeq = seq + 1 + if len(meta.Segments) > 0 { + last := meta.Segments[len(meta.Segments)-1] + index.LastSeq = last.LastSeq + index.LastOffset = last.LastOffset + index.Segment = last.Name + } + if err := sm.writeSidecarFiles(key, meta, &index); err != nil { + return nil, err + } + return meta, nil +} + +func (sm *SessionManager) loadWorkingSet(segments []sessionSegmentMeta) ([]providers.Message, error) { + if len(segments) == 0 { + return nil, nil + } + start := 0 + if sm.promptLoadSegments > 0 && len(segments) > sm.promptLoadSegments { + start = len(segments) - sm.promptLoadSegments + } + return sm.loadMessagesForSegments(segments[start:]) +} + +func (sm 
*SessionManager) loadMessagesForSegments(segments []sessionSegmentMeta) ([]providers.Message, error) { + out := make([]providers.Message, 0) + for _, segment := range segments { + path := filepath.Join(sm.storage, segment.Name) + if err := jsonlog.Scan(path, func(line []byte) error { + msg, ok := fromJSONLLine(line) + if ok { + out = append(out, msg) + } + return nil + }); err != nil { + if os.IsNotExist(err) { + continue + } + return nil, err + } + } + return out, nil +} + +func (sm *SessionManager) GetHistoryWindow(key string, around, before, after, limit int) []providers.Message { + sm.mu.RLock() + session, ok := sm.sessions[key] + sm.mu.RUnlock() + if !ok { + return nil + } + session.mu.RLock() + segments := append([]sessionSegmentMeta(nil), session.segments...) + session.mu.RUnlock() + if len(segments) == 0 { + return nil + } + startSeq, endSeq := computeHistorySeqWindow(segments, around, before, after, limit) + if endSeq < startSeq { + return nil + } + selected := make([]sessionSegmentMeta, 0, len(segments)) + for _, segment := range segments { + if segment.LastSeq < startSeq || segment.FirstSeq > endSeq { + continue + } + selected = append(selected, segment) + } + if len(selected) == 0 { + return nil + } + all, err := sm.loadMessagesForSegments(selected) + if err != nil { + return nil + } + out := make([]providers.Message, 0, len(all)) + seq := 0 + for _, segment := range selected { + for i := 0; i < segment.MessageCount && seq < len(all); i++ { + currentSeq := segment.FirstSeq + i + msg := all[seq] + seq++ + if currentSeq < startSeq || currentSeq > endSeq { + continue + } + out = append(out, msg) + } + } + return out +} + +func (sm *SessionManager) appendMessageLocked(session *Session, msg providers.Message) (appendMessageResult, error) { + if sm.storage == "" { + return appendMessageResult{}, nil + } + if err := sm.ensureActiveSegmentLocked(session); err != nil { + return appendMessageResult{}, err + } + result := appendMessageResult{} + if 
sm.shouldRolloverLocked(session) { + if err := sm.rolloverLocked(session); err != nil { + return appendMessageResult{}, err + } + result.refreshSessionsIndex = true + if err := sm.ensureActiveSegmentLocked(session); err != nil { + return appendMessageResult{}, err + } + } + active := sm.activeSegmentLocked(session) + if active == nil { + return appendMessageResult{}, fmt.Errorf("active session segment unavailable") + } + offset, err := jsonlog.AppendLine(filepath.Join(sm.storage, active.Name), toOpenClawMessageEvent(msg)) + if err != nil { + return appendMessageResult{}, err + } + if active.FirstSeq == 0 { + active.FirstSeq = session.nextSeq + } + active.LastSeq = session.nextSeq + active.MessageCount++ + active.LastOffset = offset + active.UpdatedAt = time.Now().UnixMilli() + sm.appendIndexLocked(session, msg, active.Name, active.LastOffset) + session.nextSeq++ + return result, sm.persistSidecarsLocked(session) +} + +func (sm *SessionManager) ensureActiveSegmentLocked(session *Session) error { + if session == nil { + return nil + } + for i := range session.segments { + if strings.EqualFold(session.segments[i].Name, activeSegmentFilename(session.Key)) { + return nil + } + } + session.segments = append(session.segments, sessionSegmentMeta{ + Name: activeSegmentFilename(session.Key), + Archived: false, + LastOffset: 0, + UpdatedAt: time.Now().UnixMilli(), + }) + return nil +} + +func (sm *SessionManager) activeSegmentLocked(session *Session) *sessionSegmentMeta { + if session == nil { + return nil + } + for i := range session.segments { + if strings.EqualFold(session.segments[i].Name, activeSegmentFilename(session.Key)) { + return &session.segments[i] + } + } + return nil +} + +func (sm *SessionManager) shouldRolloverLocked(session *Session) bool { + active := sm.activeSegmentLocked(session) + if active == nil { + return false + } + if sm.segmentMaxMessages > 0 && active.MessageCount >= sm.segmentMaxMessages { + return true + } + if sm.segmentMaxBytes > 0 && 
active.LastOffset >= sm.segmentMaxBytes { + return true + } + return false +} + +func (sm *SessionManager) rolloverLocked(session *Session) error { + active := sm.activeSegmentLocked(session) + if active == nil || active.MessageCount == 0 { + return nil + } + nextSeq := 1 + for _, segment := range session.segments { + if seq, ok := parseArchiveSegmentFilename(segment.Name, session.Key); ok && seq >= nextSeq { + nextSeq = seq + 1 + } + } + archiveName := archivedSegmentFilename(session.Key, nextSeq) + if err := os.Rename(filepath.Join(sm.storage, active.Name), filepath.Join(sm.storage, archiveName)); err != nil { + return err + } + active.Name = archiveName + active.Archived = true + if err := sm.ensureActiveSegmentLocked(session); err != nil { + return err + } + newActive := sm.activeSegmentLocked(session) + if newActive != nil { + newActive.Archived = false + newActive.FirstSeq = 0 + newActive.LastSeq = 0 + newActive.MessageCount = 0 + newActive.LastOffset = 0 + newActive.UpdatedAt = time.Now().UnixMilli() + } + rebuilt, err := sm.rebuildSidecars(session.Key, &sessionMetaFile{ + Summary: session.Summary, + CompactionCount: session.CompactionCount, + LastLanguage: session.LastLanguage, + PreferredLanguage: session.PreferredLanguage, + }) + if err != nil { + return err + } + if rebuilt != nil { + session.segments = append([]sessionSegmentMeta(nil), rebuilt.Segments...) 
+ session.nextSeq = maxInt(rebuilt.NextSeq, session.nextSeq) + } + return nil +} + +func (sm *SessionManager) persistSidecarsLocked(session *Session) error { + if sm.storage == "" || session == nil { + return nil + } + meta := sessionMetaFile{ + Version: 1, + SessionKey: session.Key, + SessionID: firstNonEmpty(session.SessionID, deriveSessionID(session.Key)), + Kind: firstNonEmpty(session.Kind, detectSessionKind(session.Key)), + Summary: strings.TrimSpace(session.Summary), + CompactionCount: session.CompactionCount, + LastLanguage: strings.TrimSpace(session.LastLanguage), + PreferredLanguage: strings.TrimSpace(session.PreferredLanguage), + MessageCount: sessionMessageCount(session), + CreatedAt: session.Created.UnixMilli(), + UpdatedAt: session.Updated.UnixMilli(), + NextSeq: maxInt(session.nextSeq, sessionMessageCount(session)+1), + Segments: append([]sessionSegmentMeta(nil), session.segments...), + } + if session.index == nil { + index, err := sm.buildIndexForSessionLocked(session, meta.NextSeq-1) + if err != nil { + return err + } + session.index = index + } + return sm.writeSidecarFiles(session.Key, &meta, session.index) +} + +func (sm *SessionManager) buildIndexForSessionLocked(session *Session, seqEnd int) (*sessionIndexFile, error) { + messages, err := sm.loadMessagesForSegments(session.segments) + if err != nil { + return nil, err + } + index := &sessionIndexFile{ + Version: 1, + SessionKey: session.Key, + LastSeq: 0, + LastOffset: 0, + Segment: "", + UpdatedAt: time.Now().UnixMilli(), + Tokens: map[string][]sessionIndexRef{}, + } + for i, msg := range messages { + ref := sessionIndexRef{ + Seq: i + 1, + Role: strings.ToLower(strings.TrimSpace(msg.Role)), + Segment: segmentNameForSeq(session.segments, i+1), + Snippet: messageSnippet(msg.Content), + } + appendTokens(index.Tokens, tokenizeIndexText(msg.Content), ref) + } + index.LastSeq = seqEnd + if active := sm.activeSegmentLocked(session); active != nil { + index.LastOffset = active.LastOffset + 
index.Segment = active.Name + } + return index, nil +} + +func segmentNameForSeq(segments []sessionSegmentMeta, seq int) string { + for _, segment := range segments { + if seq >= segment.FirstSeq && seq <= segment.LastSeq { + return segment.Name + } + } + return "" +} + +func sessionMessageCount(session *Session) int { + count := 0 + for _, segment := range session.segments { + count += segment.MessageCount + } + if count == 0 && len(session.Messages) > 0 { + count = len(session.Messages) + } + return count +} + +func (sm *SessionManager) writeSidecarFiles(key string, meta *sessionMetaFile, index *sessionIndexFile) error { + if err := jsonlog.WriteJSON(sm.sessionMetaPath(key), meta); err != nil { + return err + } + return jsonlog.WriteJSON(sm.sessionIndexPath(key), index) +} + +func (sm *SessionManager) readIndexFile(path string) (*sessionIndexFile, error) { + var index sessionIndexFile + if err := jsonlog.ReadJSON(path, &index); err != nil { + return nil, err + } + if index.Version <= 0 { + return nil, fmt.Errorf("invalid index version") + } + return &index, nil +} + +func (sm *SessionManager) searchSessionByScan(key, kind, summary string, updated time.Time, terms []string) (SessionSearchResult, bool) { + session := SessionSearchResult{ + Key: key, + Kind: kind, + Summary: summary, + UpdatedAt: updated, + } + all, err := sm.loadMessagesForSegments(sm.sessionSegments(key)) + if err != nil { + return SessionSearchResult{}, false + } + type scored struct { + score int + snippet SessionSearchSnippet + } + matches := make([]scored, 0) + for idx, msg := range all { + score := messageMatchScore(msg, terms) + if score == 0 { + continue + } + matches = append(matches, scored{ + score: score, + snippet: SessionSearchSnippet{ + Seq: idx + 1, + Role: strings.ToLower(strings.TrimSpace(msg.Role)), + Content: messageSnippet(msg.Content), + }, + }) + } + if len(matches) == 0 { + return SessionSearchResult{}, false + } + sort.Slice(matches, func(i, j int) bool { + if 
matches[i].score != matches[j].score { + return matches[i].score > matches[j].score + } + return matches[i].snippet.Seq > matches[j].snippet.Seq + }) + for i, item := range matches { + if i >= maxSearchSnippetsPerSession { + break + } + session.Score += item.score + session.Snippets = append(session.Snippets, item.snippet) + } + return session, true +} + +func (sm *SessionManager) sessionSegments(key string) []sessionSegmentMeta { + sm.mu.RLock() + session := sm.sessions[key] + sm.mu.RUnlock() + if session == nil { + return nil + } + session.mu.RLock() + defer session.mu.RUnlock() + return append([]sessionSegmentMeta(nil), session.segments...) +} + +func searchSessionIndex(key, kind, summary string, updated time.Time, terms []string, index *sessionIndexFile) (SessionSearchResult, bool) { + type aggregate struct { + score int + ref sessionIndexRef + } + hits := map[int]*aggregate{} + for _, term := range terms { + for _, ref := range index.Tokens[term] { + item := hits[ref.Seq] + if item == nil { + item = &aggregate{ref: ref} + hits[ref.Seq] = item + } + item.score++ + } + } + if len(hits) == 0 { + return SessionSearchResult{}, false + } + aggs := make([]aggregate, 0, len(hits)) + for _, item := range hits { + aggs = append(aggs, *item) + } + sort.Slice(aggs, func(i, j int) bool { + if aggs[i].score != aggs[j].score { + return aggs[i].score > aggs[j].score + } + return aggs[i].ref.Seq > aggs[j].ref.Seq + }) + result := SessionSearchResult{ + Key: key, + Kind: kind, + Summary: summary, + UpdatedAt: updated, + } + for i, agg := range aggs { + if i >= maxSearchSnippetsPerSession { + break + } + result.Score += agg.score + result.Snippets = append(result.Snippets, SessionSearchSnippet{ + Seq: agg.ref.Seq, + Role: agg.ref.Role, + Segment: agg.ref.Segment, + Content: agg.ref.Snippet, + }) + } + return result, true +} + +func appendTokens(dst map[string][]sessionIndexRef, tokens []string, ref sessionIndexRef) { + for _, token := range tokens { + items := dst[token] + if 
len(items) >= maxTokenRefsPerSessionToken { + continue + } + items = append(items, ref) + dst[token] = items + } +} + +func (sm *SessionManager) appendIndexLocked(session *Session, msg providers.Message, segment string, offset int64) { + if session == nil { + return + } + if session.index == nil { + session.index = &sessionIndexFile{ + Version: 1, + SessionKey: session.Key, + Tokens: map[string][]sessionIndexRef{}, + } + } + if session.index.Tokens == nil { + session.index.Tokens = map[string][]sessionIndexRef{} + } + ref := sessionIndexRef{ + Seq: session.nextSeq, + Role: strings.ToLower(strings.TrimSpace(msg.Role)), + Segment: segment, + Snippet: messageSnippet(msg.Content), + } + appendTokens(session.index.Tokens, tokenizeIndexText(msg.Content), ref) + session.index.LastSeq = session.nextSeq + session.index.LastOffset = offset + session.index.Segment = segment + session.index.UpdatedAt = time.Now().UnixMilli() +} + +func messageMatchScore(msg providers.Message, terms []string) int { + if len(terms) == 0 { + return 0 + } + content := strings.ToLower(msg.Content) + score := 0 + for _, term := range terms { + if strings.Contains(content, term) { + score++ + } + } + return score +} + +func messageSnippet(content string) string { + content = strings.TrimSpace(strings.ReplaceAll(content, "\n", " ")) + if len(content) <= 180 { + return content + } + return content[:180] + "..." 
+} + +func tokenizeIndexText(text string) []string { + return tokenizeSearchText(text, false) +} + +func tokenizeQueryText(text string) []string { + return tokenizeSearchText(text, true) +} + +func tokenizeSearchText(text string, includeSingleHan bool) []string { + text = strings.ToLower(strings.TrimSpace(text)) + if text == "" { + return nil + } + out := make([]string, 0, 16) + seen := map[string]struct{}{} + var asciiBuf strings.Builder + flushASCII := func() { + if asciiBuf.Len() == 0 { + return + } + fields := strings.FieldsFunc(asciiBuf.String(), func(r rune) bool { + return !unicode.IsLetter(r) && !unicode.IsNumber(r) + }) + for _, field := range fields { + field = strings.TrimSpace(field) + if len(field) < 2 { + continue + } + if _, ok := seen[field]; ok { + continue + } + seen[field] = struct{}{} + out = append(out, field) + } + asciiBuf.Reset() + } + var hanRunes []rune + flushHan := func() { + if len(hanRunes) == 0 { + return + } + if includeSingleHan && len(hanRunes) == 1 { + token := string(hanRunes[0]) + if _, ok := seen[token]; !ok { + seen[token] = struct{}{} + out = append(out, token) + } + } + if len(hanRunes) >= 2 { + for i := 0; i < len(hanRunes)-1; i++ { + token := string(hanRunes[i : i+2]) + if _, ok := seen[token]; ok { + continue + } + seen[token] = struct{}{} + out = append(out, token) + } + } + hanRunes = hanRunes[:0] + } + for _, r := range text { + switch { + case unicode.Is(unicode.Han, r): + flushASCII() + hanRunes = append(hanRunes, r) + case unicode.IsLetter(r) || unicode.IsNumber(r): + flushHan() + asciiBuf.WriteRune(r) + default: + flushASCII() + flushHan() + } + } + flushASCII() + flushHan() + return out +} + +func (sm *SessionManager) discoverSessionKeys() ([]string, error) { + if sm.storage == "" { + return nil, nil + } + entries, err := os.ReadDir(sm.storage) + if err != nil { + return nil, err + } + keys := map[string]struct{}{} + for _, entry := range entries { + if entry.IsDir() { + continue + } + if key, ok := 
sessionKeyFromFilename(entry.Name()); ok { + keys[key] = struct{}{} + } + } + out := make([]string, 0, len(keys)) + for key := range keys { + out = append(out, key) + } + sort.Strings(out) + return out, nil +} + +func (sm *SessionManager) discoverSessionSegments(key string) ([]string, error) { + entries, err := os.ReadDir(sm.storage) + if err != nil { + return nil, err + } + names := make([]string, 0) + for _, entry := range entries { + if entry.IsDir() { + continue + } + name := entry.Name() + switch { + case name == activeSegmentFilename(key): + names = append(names, name) + case name == legacySegmentFilename(key): + names = append(names, name) + default: + if _, ok := parseArchiveSegmentFilename(name, key); ok { + names = append(names, name) + } + } + } + sort.Slice(names, func(i, j int) bool { + return compareSegmentNames(key, names[i], names[j]) + }) + return names, nil +} + +func compareSegmentNames(key, left, right string) bool { + if left == right { + return false + } + if left == legacySegmentFilename(key) { + return true + } + if right == legacySegmentFilename(key) { + return false + } + if left == activeSegmentFilename(key) { + return false + } + if right == activeSegmentFilename(key) { + return true + } + li, _ := parseArchiveSegmentFilename(left, key) + ri, _ := parseArchiveSegmentFilename(right, key) + return li < ri +} + +func sessionKeyFromFilename(name string) (string, bool) { + switch { + case strings.HasSuffix(name, ".meta.json"): + return strings.TrimSuffix(name, ".meta.json"), true + case strings.HasSuffix(name, ".index.json"): + return strings.TrimSuffix(name, ".index.json"), true + case strings.HasSuffix(name, ".active.jsonl"): + return strings.TrimSuffix(name, ".active.jsonl"), true + case strings.HasSuffix(name, ".jsonl") && !strings.Contains(name, ".deleted."): + if m := archiveSegmentRe.FindStringSubmatch(name); len(m) == 3 { + return m[1], true + } + return strings.TrimSuffix(name, ".jsonl"), true + default: + return "", false + } +} + 
+func parseArchiveSegmentFilename(name, key string) (int, bool) { + m := archiveSegmentRe.FindStringSubmatch(name) + if len(m) != 3 || m[1] != key { + return 0, false + } + n, err := strconv.Atoi(m[2]) + if err != nil { + return 0, false + } + return n, true +} + +func activeSegmentFilename(key string) string { return key + ".active.jsonl" } +func legacySegmentFilename(key string) string { return key + ".jsonl" } +func archivedSegmentFilename(key string, seq int) string { + return fmt.Sprintf("%s.%04d.jsonl", key, seq) +} + +func (sm *SessionManager) sessionMetaPath(key string) string { + return filepath.Join(sm.storage, key+".meta.json") +} +func (sm *SessionManager) sessionIndexPath(key string) string { + return filepath.Join(sm.storage, key+".index.json") +} + +func readPositiveIntEnv(name string, fallback int) int { + if value := strings.TrimSpace(os.Getenv(name)); value != "" { + if n, err := strconv.Atoi(value); err == nil && n > 0 { + return n + } + } + return fallback +} + +func firstNonEmpty(values ...string) string { + for _, value := range values { + if strings.TrimSpace(value) != "" { + return value + } + } + return "" +} + +func seedValue(seed *sessionMetaFile, get func(*sessionMetaFile) string) string { + if seed == nil { + return "" + } + return get(seed) +} + +func seedInt(seed *sessionMetaFile, get func(*sessionMetaFile) int) int { + if seed == nil { + return 0 + } + return get(seed) +} + +func maxInt(values ...int) int { + best := 0 + for _, value := range values { + if value > best { + best = value + } + } + return best +} + +func cloneSessionIndex(index *sessionIndexFile) *sessionIndexFile { + if index == nil { + return nil + } + out := &sessionIndexFile{ + Version: index.Version, + SessionKey: index.SessionKey, + LastSeq: index.LastSeq, + LastOffset: index.LastOffset, + Segment: index.Segment, + UpdatedAt: index.UpdatedAt, + Tokens: make(map[string][]sessionIndexRef, len(index.Tokens)), + } + for key, refs := range index.Tokens { + 
out.Tokens[key] = append([]sessionIndexRef(nil), refs...) + } + return out +} + +func computeHistorySeqWindow(segments []sessionSegmentMeta, around, before, after, limit int) (int, int) { + total := 0 + for _, segment := range segments { + if segment.LastSeq > total { + total = segment.LastSeq + } + } + if total <= 0 { + return 1, 0 + } + if limit <= 0 { + limit = 50 + } + start := 1 + end := total + if around > 0 { + half := limit / 2 + if half < 1 { + half = 1 + } + start = around - half + end = around + half + if start < 1 { + start = 1 + } + if end > total { + end = total + } + } else { + if after > 0 { + start = after + 1 + } + if before > 0 { + end = before - 1 + } + } + if start < 1 { + start = 1 + } + if end > total { + end = total + } + if end < start { + return 1, 0 + } + if end-start+1 > limit { + start = end - limit + 1 + if start < 1 { + start = 1 + } + } + return start, end +} diff --git a/pkg/session/manager_test.go b/pkg/session/manager_test.go index a28af3d..2fc2b88 100644 --- a/pkg/session/manager_test.go +++ b/pkg/session/manager_test.go @@ -5,6 +5,10 @@ import ( "path/filepath" "strings" "testing" + "time" + + "github.com/YspCoder/clawgo/pkg/jsonlog" + "github.com/YspCoder/clawgo/pkg/providers" ) func TestLoadSessionsReturnsScannerErrorForOversizedLine(t *testing.T) { @@ -38,3 +42,197 @@ func TestFromJSONLLineParsesOpenClawToolResult(t *testing.T) { } } +func TestSessionManagerWritesSidecarsAndSearches(t *testing.T) { + t.Parallel() + + storage := t.TempDir() + sm := NewSessionManager(storage) + key := "cli:default" + + sm.AddMessage(key, "user", "deploy project alpha") + sm.AddMessage(key, "assistant", "deployment failed with timeout after contacting api gateway") + + for _, name := range []string{ + key + ".active.jsonl", + key + ".meta.json", + key + ".index.json", + } { + if _, err := os.Stat(filepath.Join(storage, name)); err != nil { + t.Fatalf("expected artifact %s: %v", name, err) + } + } + + results := sm.Search("deploy timeout", nil, 
"", 5) + if len(results) != 1 { + t.Fatalf("expected one search result, got %#v", results) + } + if results[0].Key != key || len(results[0].Snippets) == 0 { + t.Fatalf("unexpected search result: %#v", results[0]) + } +} + +func TestSessionManagerRebuildsMissingSidecarsFromJSONL(t *testing.T) { + t.Parallel() + + storage := t.TempDir() + sm := NewSessionManager(storage) + key := "cli:summary" + + sm.AddMessage(key, "user", "remember previous deploy steps") + sm.SetSummary(key, "Key Facts\n- Previous deploy steps were discussed.") + + if err := os.Remove(filepath.Join(storage, key+".meta.json")); err != nil { + t.Fatalf("remove meta: %v", err) + } + if err := os.Remove(filepath.Join(storage, key+".index.json")); err != nil { + t.Fatalf("remove index: %v", err) + } + + reloaded := NewSessionManager(storage) + if got := reloaded.GetSummary(key); !strings.Contains(got, "Previous deploy steps") { + t.Fatalf("expected summary recovered from fallback index, got %q", got) + } + results := reloaded.Search("deploy", nil, "", 5) + if len(results) != 1 || results[0].Key != key { + t.Fatalf("expected rebuilt search result, got %#v", results) + } +} + +func TestSessionManagerRollsOverActiveSegment(t *testing.T) { + t.Setenv("CLAWGO_SESSION_SEGMENT_MAX_MESSAGES", "2") + storage := t.TempDir() + sm := NewSessionManager(storage) + key := "cli:rollover" + + sm.AddMessage(key, "user", "one") + sm.AddMessage(key, "assistant", "two") + sm.AddMessage(key, "user", "three") + + if _, err := os.Stat(filepath.Join(storage, key+".0001.jsonl")); err != nil { + t.Fatalf("expected archived segment: %v", err) + } + if _, err := os.Stat(filepath.Join(storage, key+".active.jsonl")); err != nil { + t.Fatalf("expected new active segment: %v", err) + } + history := sm.GetHistory(key) + if len(history) != 3 { + t.Fatalf("expected full history across segments, got %d", len(history)) + } +} + +func TestSessionManagerHistoryWindowAndIncrementalIndex(t *testing.T) { + 
t.Setenv("CLAWGO_SESSION_SEGMENT_MAX_MESSAGES", "2") + storage := t.TempDir() + sm := NewSessionManager(storage) + key := "cli:window" + + sm.AddMessage(key, "user", "one") + sm.AddMessage(key, "assistant", "two") + sm.AddMessage(key, "user", "three") + sm.AddMessage(key, "assistant", "four") + + window := sm.GetHistoryWindow(key, 0, 0, 2, 2) + if len(window) != 2 || window[0].Content != "three" || window[1].Content != "four" { + t.Fatalf("unexpected history window: %#v", window) + } + + var index sessionIndexFile + if err := jsonlog.ReadJSON(filepath.Join(storage, key+".index.json"), &index); err != nil { + t.Fatalf("read index: %v", err) + } + size, err := jsonlog.FileSize(filepath.Join(storage, key+".active.jsonl")) + if err != nil { + t.Fatalf("file size: %v", err) + } + if index.LastSeq != 4 { + t.Fatalf("expected last seq 4, got %d", index.LastSeq) + } + if index.LastOffset != size { + t.Fatalf("expected last offset %d, got %d", size, index.LastOffset) + } + if index.Segment != key+".active.jsonl" { + t.Fatalf("unexpected index segment %q", index.Segment) + } +} + +func TestSessionManagerSearchSupportsChineseBigrams(t *testing.T) { + t.Parallel() + + storage := t.TempDir() + sm := NewSessionManager(storage) + key := "cli:zh" + + sm.AddMessage(key, "user", "之前讨论过发布回滚方案") + sm.AddMessage(key, "assistant", "回滚需要先确认数据库版本") + + results := sm.Search("回滚方案", nil, "", 5) + if len(results) != 1 || results[0].Key != key { + t.Fatalf("expected chinese query to hit sidecar index, got %#v", results) + } + + if err := os.Remove(filepath.Join(storage, key+".index.json")); err != nil { + t.Fatalf("remove index: %v", err) + } + reloaded := NewSessionManager(storage) + results = reloaded.Search("回滚方案", nil, "", 5) + if len(results) != 1 || results[0].Key != key { + t.Fatalf("expected chinese query to hit scan fallback, got %#v", results) + } +} + +func TestApplyCompactionIfUnchangedRejectsChangedSession(t *testing.T) { + t.Parallel() + + sm := NewSessionManager(t.TempDir()) + 
key := "cli:guard" + sm.AddMessage(key, "user", "one") + sm.AddMessage(key, "assistant", "two") + snapshot := sm.CompactionSnapshot(key) + sm.AddMessage(key, "user", "three") + + applied := sm.ApplyCompactionIfUnchanged(key, snapshot.NextSeq, snapshot.Summary, []providers.Message{{Role: "assistant", Content: "two"}}, "Key Facts\n- compacted") + if applied { + t.Fatal("expected stale compaction application to be rejected") + } + history := sm.GetPromptHistory(key) + if len(history) != 3 || history[2].Content != "three" { + t.Fatalf("expected newer message to remain, got %#v", history) + } +} + +func TestSessionManagerAppendDoesNotRewriteSessionsIndex(t *testing.T) { + t.Parallel() + + storage := t.TempDir() + sm := NewSessionManager(storage) + key := "cli:index" + + sm.AddMessage(key, "user", "first") + indexPath := filepath.Join(storage, "sessions.json") + before, err := os.ReadFile(indexPath) + if err != nil { + t.Fatalf("read sessions index: %v", err) + } + statBefore, err := os.Stat(indexPath) + if err != nil { + t.Fatalf("stat sessions index: %v", err) + } + + time.Sleep(10 * time.Millisecond) + sm.AddMessage(key, "assistant", "second") + + after, err := os.ReadFile(indexPath) + if err != nil { + t.Fatalf("read sessions index after append: %v", err) + } + statAfter, err := os.Stat(indexPath) + if err != nil { + t.Fatalf("stat sessions index after append: %v", err) + } + if string(before) != string(after) { + t.Fatalf("expected sessions.json to stay unchanged for hot-path append") + } + if !statAfter.ModTime().Equal(statBefore.ModTime()) { + t.Fatalf("expected sessions.json mtime to stay unchanged, before=%s after=%s", statBefore.ModTime(), statAfter.ModTime()) + } +} diff --git a/pkg/tools/runtime_types.go b/pkg/tools/runtime_types.go index feac799..a07fba8 100644 --- a/pkg/tools/runtime_types.go +++ b/pkg/tools/runtime_types.go @@ -66,16 +66,17 @@ type RunRecord struct { } type EventRecord struct { - ID string `json:"id,omitempty"` - RunID string 
`json:"run_id,omitempty"` - RequestID string `json:"request_id,omitempty"` - AgentID string `json:"agent_id,omitempty"` - Type string `json:"type"` - Status string `json:"status,omitempty"` - Message string `json:"message,omitempty"` - RetryCount int `json:"retry_count,omitempty"` - Error *RuntimeError `json:"error,omitempty"` - At int64 `json:"ts"` + ID string `json:"id,omitempty"` + RunID string `json:"run_id,omitempty"` + RequestID string `json:"request_id,omitempty"` + AgentID string `json:"agent_id,omitempty"` + Type string `json:"type"` + Status string `json:"status,omitempty"` + FailureCode string `json:"failure_code,omitempty"` + Message string `json:"message,omitempty"` + RetryCount int `json:"retry_count,omitempty"` + Error *RuntimeError `json:"error,omitempty"` + At int64 `json:"ts"` } type ArtifactRecord struct { diff --git a/pkg/tools/session_search.go b/pkg/tools/session_search.go new file mode 100644 index 0000000..cf9ace6 --- /dev/null +++ b/pkg/tools/session_search.go @@ -0,0 +1,103 @@ +package tools + +import ( + "context" + "fmt" + "strings" + + "github.com/YspCoder/clawgo/pkg/session" +) + +type SessionSearchTool struct { + manager *session.SessionManager +} + +func NewSessionSearchTool(manager *session.SessionManager) *SessionSearchTool { + return &SessionSearchTool{manager: manager} +} + +func (t *SessionSearchTool) Name() string { return "session_search" } + +func (t *SessionSearchTool) Description() string { + return "Search past session history across JSONL session logs. Use when the user refers to previous conversations, past decisions, or earlier project work." 
+} + +func (t *SessionSearchTool) Parameters() map[string]interface{} { + return map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "query": map[string]interface{}{ + "type": "string", + "description": "Keywords or short phrase to search across past sessions.", + }, + "limit": map[string]interface{}{ + "type": "integer", + "description": "Maximum number of matching sessions to return.", + "default": 5, + }, + "kinds": map[string]interface{}{ + "type": "array", + "description": "Optional session kinds filter, e.g. main, cron, subagent, hook.", + "items": map[string]interface{}{"type": "string"}, + }, + "exclude_current": map[string]interface{}{ + "type": "boolean", + "description": "Exclude the current session from results when session_key is provided.", + "default": true, + }, + "session_key": map[string]interface{}{ + "type": "string", + "description": "Optional current session key for exclude_current behavior.", + }, + }, + "required": []string{"query"}, + } +} + +func (t *SessionSearchTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) { + _ = ctx + if t == nil || t.manager == nil { + return "", fmt.Errorf("session manager not configured") + } + query := MapStringArg(args, "query") + if query == "" { + return "", fmt.Errorf("query is required") + } + limit := MapIntArg(args, "limit", 5) + kinds := MapStringListArg(args, "kinds") + excludeCurrent, excludeSet := MapBoolArg(args, "exclude_current") + excludeKey := "" + if excludeSet && excludeCurrent { + excludeKey = MapStringArg(args, "session_key") + } + + results := t.manager.Search(query, kinds, excludeKey, limit) + if len(results) == 0 { + return fmt.Sprintf("No past sessions matched %q.", query), nil + } + + var sb strings.Builder + sb.WriteString(fmt.Sprintf("Session search results for %q:\n\n", query)) + for _, item := range results { + sb.WriteString(fmt.Sprintf("- %s kind=%s updated=%s score=%d\n", item.Key, item.Kind, 
item.UpdatedAt.Format("2006-01-02 15:04:05"), item.Score)) + if summary := strings.TrimSpace(item.Summary); summary != "" { + sb.WriteString(" summary: " + singleLine(summary, 220) + "\n") + } + for _, snippet := range item.Snippets { + line := singleLine(snippet.Content, 220) + sb.WriteString(fmt.Sprintf(" [#%d][%s] %s\n", snippet.Seq, snippet.Role, line)) + } + sb.WriteString("\n") + } + return strings.TrimSpace(sb.String()), nil +} + +func (t *SessionSearchTool) ParallelSafe() bool { return true } + +func singleLine(s string, max int) string { + s = strings.TrimSpace(strings.ReplaceAll(s, "\n", " ")) + if max > 0 && len(s) > max { + return s[:max] + "..." + } + return s +} diff --git a/pkg/tools/spawn.go b/pkg/tools/spawn.go index d3fd356..366cbb9 100644 --- a/pkg/tools/spawn.go +++ b/pkg/tools/spawn.go @@ -61,6 +61,10 @@ func (t *SpawnTool) Parameters() map[string]interface{} { "type": "integer", "description": "Optional per-attempt timeout in seconds.", }, + "max_tool_iterations": map[string]interface{}{ + "type": "integer", + "description": "Optional independent tool-calling iteration budget.", + }, "max_task_chars": map[string]interface{}{ "type": "integer", "description": "Optional task size quota in characters.", @@ -101,6 +105,7 @@ func (t *SpawnTool) Execute(ctx context.Context, args map[string]interface{}) (s maxRetries := MapIntArg(args, "max_retries", 0) retryBackoff := MapIntArg(args, "retry_backoff_ms", 0) timeoutSec := MapIntArg(args, "timeout_sec", 0) + maxToolIterations := MapIntArg(args, "max_tool_iterations", 0) maxTaskChars := MapIntArg(args, "max_task_chars", 0) maxResultChars := MapIntArg(args, "max_result_chars", 0) if label == "" && role != "" { @@ -129,17 +134,18 @@ func (t *SpawnTool) Execute(ctx context.Context, args map[string]interface{}) (s } result, err := t.manager.Spawn(ctx, SubagentSpawnOptions{ - Task: task, - Label: label, - Role: role, - AgentID: agentID, - MaxRetries: maxRetries, - RetryBackoff: retryBackoff, - TimeoutSec: 
timeoutSec, - MaxTaskChars: maxTaskChars, - MaxResultChars: maxResultChars, - OriginChannel: originChannel, - OriginChatID: originChatID, + Task: task, + Label: label, + Role: role, + AgentID: agentID, + MaxRetries: maxRetries, + RetryBackoff: retryBackoff, + TimeoutSec: timeoutSec, + MaxToolIterations: maxToolIterations, + MaxTaskChars: maxTaskChars, + MaxResultChars: maxResultChars, + OriginChannel: originChannel, + OriginChatID: originChatID, }) if err != nil { return "", fmt.Errorf("failed to spawn subagent: %w", err) diff --git a/pkg/tools/subagent.go b/pkg/tools/subagent.go index 6854a48..717d97a 100644 --- a/pkg/tools/subagent.go +++ b/pkg/tools/subagent.go @@ -15,37 +15,42 @@ import ( ) type SubagentRun struct { - ID string `json:"id"` - Task string `json:"task"` - Label string `json:"label"` - Role string `json:"role"` - AgentID string `json:"agent_id"` - Transport string `json:"transport,omitempty"` - ParentAgentID string `json:"parent_agent_id,omitempty"` - NotifyMainPolicy string `json:"notify_main_policy,omitempty"` - SessionKey string `json:"session_key"` - MemoryNS string `json:"memory_ns"` - SystemPromptFile string `json:"system_prompt_file,omitempty"` - ToolAllowlist []string `json:"tool_allowlist,omitempty"` - MaxRetries int `json:"max_retries,omitempty"` - RetryBackoff int `json:"retry_backoff,omitempty"` - TimeoutSec int `json:"timeout_sec,omitempty"` - MaxTaskChars int `json:"max_task_chars,omitempty"` - MaxResultChars int `json:"max_result_chars,omitempty"` - RetryCount int `json:"retry_count,omitempty"` - ThreadID string `json:"thread_id,omitempty"` - CorrelationID string `json:"correlation_id,omitempty"` - ParentRunID string `json:"parent_run_id,omitempty"` - LastMessageID string `json:"last_message_id,omitempty"` - WaitingReply bool `json:"waiting_for_reply,omitempty"` - SharedState map[string]interface{} `json:"shared_state,omitempty"` - OriginChannel string `json:"origin_channel,omitempty"` - OriginChatID string 
`json:"origin_chat_id,omitempty"` - Status string `json:"status"` - Result string `json:"result,omitempty"` - Steering []string `json:"steering,omitempty"` - Created int64 `json:"created"` - Updated int64 `json:"updated"` + ID string `json:"id"` + Task string `json:"task"` + Label string `json:"label"` + Role string `json:"role"` + AgentID string `json:"agent_id"` + Transport string `json:"transport,omitempty"` + ParentAgentID string `json:"parent_agent_id,omitempty"` + NotifyMainPolicy string `json:"notify_main_policy,omitempty"` + SessionKey string `json:"session_key"` + MemoryNS string `json:"memory_ns"` + SystemPromptFile string `json:"system_prompt_file,omitempty"` + ToolAllowlist []string `json:"tool_allowlist,omitempty"` + MaxRetries int `json:"max_retries,omitempty"` + RetryBackoff int `json:"retry_backoff,omitempty"` + TimeoutSec int `json:"timeout_sec,omitempty"` + MaxToolIterations int `json:"max_tool_iterations,omitempty"` + MaxTaskChars int `json:"max_task_chars,omitempty"` + MaxResultChars int `json:"max_result_chars,omitempty"` + RetryCount int `json:"retry_count,omitempty"` + IterationCount int `json:"iteration_count,omitempty"` + AttemptCount int `json:"attempt_count,omitempty"` + RestartCount int `json:"restart_count,omitempty"` + LastFailureCode string `json:"last_failure_code,omitempty"` + ThreadID string `json:"thread_id,omitempty"` + CorrelationID string `json:"correlation_id,omitempty"` + ParentRunID string `json:"parent_run_id,omitempty"` + LastMessageID string `json:"last_message_id,omitempty"` + WaitingReply bool `json:"waiting_for_reply,omitempty"` + SharedState map[string]interface{} `json:"shared_state,omitempty"` + OriginChannel string `json:"origin_channel,omitempty"` + OriginChatID string `json:"origin_chat_id,omitempty"` + Status string `json:"status"` + Result string `json:"result,omitempty"` + Steering []string `json:"steering,omitempty"` + Created int64 `json:"created"` + Updated int64 `json:"updated"` } type SubagentManager 
struct { @@ -66,21 +71,22 @@ type SubagentManager struct { } type SubagentSpawnOptions struct { - Task string - Label string - Role string - AgentID string - NotifyMainPolicy string - MaxRetries int - RetryBackoff int - TimeoutSec int - MaxTaskChars int - MaxResultChars int - OriginChannel string - OriginChatID string - ThreadID string - CorrelationID string - ParentRunID string + Task string + Label string + Role string + AgentID string + NotifyMainPolicy string + MaxRetries int + RetryBackoff int + TimeoutSec int + MaxToolIterations int + MaxTaskChars int + MaxResultChars int + OriginChannel string + OriginChatID string + ThreadID string + CorrelationID string + ParentRunID string } func NewSubagentManager(provider providers.LLMProvider, workspace string, bus *bus.MessageBus) *SubagentManager { @@ -173,6 +179,7 @@ func (sm *SubagentManager) spawnRun(ctx context.Context, opts SubagentSpawnOptio maxRetries := 0 retryBackoff := 1000 timeoutSec := 0 + maxToolIterations := 0 maxTaskChars := 0 maxResultChars := 0 if profile == nil && sm.profileStore != nil { @@ -206,6 +213,7 @@ func (sm *SubagentManager) spawnRun(ctx context.Context, opts SubagentSpawnOptio maxRetries = profile.MaxRetries retryBackoff = profile.RetryBackoff timeoutSec = profile.TimeoutSec + maxToolIterations = profile.MaxToolIterations maxTaskChars = profile.MaxTaskChars maxResultChars = profile.MaxResultChars } @@ -218,6 +226,9 @@ func (sm *SubagentManager) spawnRun(ctx context.Context, opts SubagentSpawnOptio if opts.TimeoutSec > 0 { timeoutSec = opts.TimeoutSec } + if opts.MaxToolIterations > 0 { + maxToolIterations = opts.MaxToolIterations + } if opts.MaxTaskChars > 0 { maxTaskChars = opts.MaxTaskChars } @@ -230,6 +241,7 @@ func (sm *SubagentManager) spawnRun(ctx context.Context, opts SubagentSpawnOptio maxRetries = normalizePositiveBound(maxRetries, 0, 8) retryBackoff = normalizePositiveBound(retryBackoff, 500, 120000) timeoutSec = normalizePositiveBound(timeoutSec, 0, 3600) + maxToolIterations = 
normalizePositiveBound(maxToolIterations, 0, 200) maxTaskChars = normalizePositiveBound(maxTaskChars, 0, 400000) maxResultChars = normalizePositiveBound(maxResultChars, 0, 400000) if role == "" { @@ -270,32 +282,33 @@ func (sm *SubagentManager) spawnRun(ctx context.Context, opts SubagentSpawnOptio } } subagentRun := &SubagentRun{ - ID: runID, - Task: task, - Label: label, - Role: role, - AgentID: agentID, - Transport: transport, - ParentAgentID: parentAgentID, - NotifyMainPolicy: notifyMainPolicy, - SessionKey: sessionKey, - MemoryNS: memoryNS, - SystemPromptFile: systemPromptFile, - ToolAllowlist: toolAllowlist, - MaxRetries: maxRetries, - RetryBackoff: retryBackoff, - TimeoutSec: timeoutSec, - MaxTaskChars: maxTaskChars, - MaxResultChars: maxResultChars, - RetryCount: 0, - ThreadID: threadID, - CorrelationID: correlationID, - ParentRunID: parentRunID, - OriginChannel: originChannel, - OriginChatID: originChatID, - Status: RuntimeStatusRouting, - Created: now, - Updated: now, + ID: runID, + Task: task, + Label: label, + Role: role, + AgentID: agentID, + Transport: transport, + ParentAgentID: parentAgentID, + NotifyMainPolicy: notifyMainPolicy, + SessionKey: sessionKey, + MemoryNS: memoryNS, + SystemPromptFile: systemPromptFile, + ToolAllowlist: toolAllowlist, + MaxRetries: maxRetries, + RetryBackoff: retryBackoff, + TimeoutSec: timeoutSec, + MaxToolIterations: maxToolIterations, + MaxTaskChars: maxTaskChars, + MaxResultChars: maxResultChars, + RetryCount: 0, + ThreadID: threadID, + CorrelationID: correlationID, + ParentRunID: parentRunID, + OriginChannel: originChannel, + OriginChatID: originChatID, + Status: RuntimeStatusRouting, + Created: now, + Updated: now, } taskCtx, cancel := context.WithCancel(ctx) sm.runs[runID] = subagentRun @@ -335,6 +348,7 @@ func (sm *SubagentManager) runSubagent(ctx context.Context, run *SubagentRun) { sm.mu.Lock() if runErr != nil { run.Status = RuntimeStatusFailed + run.LastFailureCode = classifySubagentFailureCode(runErr) 
run.Result = fmt.Sprintf("Error: %v", runErr) run.Result = applySubagentResultQuota(run.Result, run.MaxResultChars) run.Updated = time.Now().UnixMilli() @@ -354,6 +368,7 @@ func (sm *SubagentManager) runSubagent(ctx context.Context, run *SubagentRun) { sm.notifyRunWaitersLocked(run.ID) } else { run.Status = RuntimeStatusCompleted + run.LastFailureCode = "" run.Result = applySubagentResultQuota(result, run.MaxResultChars) run.Updated = time.Now().UnixMilli() run.WaitingReply = false @@ -383,16 +398,17 @@ func (sm *SubagentManager) runSubagent(ctx context.Context, run *SubagentRun) { SessionKey: run.SessionKey, Content: announceContent, Metadata: map[string]string{ - "trigger": "subagent", - "subagent_id": run.ID, - "agent_id": run.AgentID, - "role": run.Role, - "session_key": run.SessionKey, - "memory_ns": run.MemoryNS, - "retry_count": fmt.Sprintf("%d", run.RetryCount), - "timeout_sec": fmt.Sprintf("%d", run.TimeoutSec), - "status": run.Status, - "notify_reason": notifyReason, + "trigger": "subagent", + "subagent_id": run.ID, + "agent_id": run.AgentID, + "role": run.Role, + "session_key": run.SessionKey, + "memory_ns": run.MemoryNS, + "retry_count": fmt.Sprintf("%d", run.RetryCount), + "iteration_count": fmt.Sprintf("%d", run.IterationCount), + "timeout_sec": fmt.Sprintf("%d", run.TimeoutSec), + "status": run.Status, + "notify_reason": notifyReason, }, }) } @@ -504,6 +520,10 @@ func (sm *SubagentManager) runWithRetry(ctx context.Context, run *SubagentRun) ( var lastErr error for attempt := 0; attempt <= maxRetries; attempt++ { + if remaining := sm.remainingIterations(run); remaining == 0 && run.MaxToolIterations > 0 { + run.LastFailureCode = "retry_limit" + return "", fmt.Errorf("subagent iteration budget exhausted") + } runCtx := ctx var cancel context.CancelFunc if timeoutSec > 0 { @@ -516,18 +536,20 @@ func (sm *SubagentManager) runWithRetry(ctx context.Context, run *SubagentRun) ( if err == nil { sm.mu.Lock() run.RetryCount = attempt + run.LastFailureCode = "" 
run.Updated = time.Now().UnixMilli() sm.persistRunLocked(run, "attempt_succeeded", "") sm.mu.Unlock() return result, nil } lastErr = err + run.LastFailureCode = classifySubagentFailureCode(err) sm.mu.Lock() run.RetryCount = attempt run.Updated = time.Now().UnixMilli() sm.persistRunLocked(run, "attempt_failed", err.Error()) sm.mu.Unlock() - if attempt >= maxRetries { + if attempt >= maxRetries || !shouldRetrySubagentError(err) { break } select { @@ -547,8 +569,14 @@ func (sm *SubagentManager) executeRunOnce(ctx context.Context, run *SubagentRun) return "", fmt.Errorf("subagent run is nil") } pending, consumedIDs := sm.consumeThreadInbox(run) + stats := &SubagentExecutionStats{} + ctx = WithSubagentExecutionStats(ctx, stats) if sm.runFunc != nil { + if remaining := sm.remainingIterations(run); remaining > 0 { + ctx = WithSubagentIterationBudget(ctx, remaining) + } result, err := sm.runFunc(ctx, run) + sm.applyExecutionStats(run, stats) if err != nil { sm.restoreMessageStatuses(consumedIDs) } else { @@ -582,6 +610,8 @@ func (sm *SubagentManager) executeRunOnce(ctx context.Context, run *SubagentRun) response, err := sm.provider.Chat(ctx, messages, nil, sm.provider.GetDefaultModel(), map[string]interface{}{ "max_tokens": 4096, }) + stats.Attempts++ + sm.applyExecutionStats(run, stats) if err != nil { sm.restoreMessageStatuses(consumedIDs) return "", err @@ -723,15 +753,16 @@ func (sm *SubagentManager) RuntimeSnapshot(limit int) RuntimeSnapshot { if evts, err := sm.Events(run.ID, limit); err == nil { for _, evt := range evts { snapshot.Events = append(snapshot.Events, EventRecord{ - ID: EventRecordID(evt.RunID, evt.Type, evt.At), - RunID: evt.RunID, - RequestID: evt.RunID, - AgentID: evt.AgentID, - Type: evt.Type, - Status: evt.Status, - Message: evt.Message, - RetryCount: evt.RetryCount, - At: evt.At, + ID: EventRecordID(evt.RunID, evt.Type, evt.At), + RunID: evt.RunID, + RequestID: evt.RunID, + AgentID: evt.AgentID, + Type: evt.Type, + Status: evt.Status, + 
FailureCode: evt.FailureCode, + Message: evt.Message, + RetryCount: evt.RetryCount, + At: evt.At, }) } } @@ -823,6 +854,71 @@ func applySubagentResultQuota(result string, maxChars int) string { return strings.TrimSpace(trimmed) + suffix } +func (sm *SubagentManager) remainingIterations(run *SubagentRun) int { + if run == nil || run.MaxToolIterations <= 0 { + return 0 + } + remaining := run.MaxToolIterations - run.IterationCount + if remaining < 0 { + return 0 + } + return remaining +} + +func (sm *SubagentManager) applyExecutionStats(run *SubagentRun, stats *SubagentExecutionStats) { + if run == nil || stats == nil { + return + } + run.IterationCount += stats.Iterations + run.AttemptCount += stats.Attempts + run.RestartCount += stats.Restarts + if strings.TrimSpace(stats.FailureCode) != "" { + run.LastFailureCode = strings.TrimSpace(stats.FailureCode) + } +} + +func shouldRetrySubagentError(err error) bool { + code := classifySubagentFailureCode(err) + switch code { + case "", "timeout", "stream_failed", "stream_stale", "context_compacted": + return true + case "continuation_exhausted", "retry_limit": + return false + default: + return !errors.Is(err, context.Canceled) + } +} + +func classifySubagentFailureCode(err error) string { + if err == nil { + return "" + } + if errors.Is(err, context.DeadlineExceeded) { + return "timeout" + } + var execErr *providers.ProviderExecutionError + if errors.As(err, &execErr) && execErr != nil { + if strings.TrimSpace(execErr.Code) != "" { + return strings.TrimSpace(execErr.Code) + } + } + lower := strings.ToLower(strings.TrimSpace(err.Error())) + switch { + case strings.Contains(lower, "max tool iterations"), strings.Contains(lower, "iteration budget exhausted"): + return "retry_limit" + case strings.Contains(lower, "stream stale"): + return "stream_stale" + case strings.Contains(lower, "stream failed"): + return "stream_failed" + case strings.Contains(lower, "continuation exhausted"), strings.Contains(lower, "thinking budget 
exhausted"): + return "continuation_exhausted" + case strings.Contains(lower, "compaction"): + return "context_compacted" + default: + return "" + } +} + func normalizeSubagentIdentifier(in string) string { in = strings.TrimSpace(strings.ToLower(in)) if in == "" { @@ -867,13 +963,14 @@ func (sm *SubagentManager) persistRunLocked(run *SubagentRun, eventType, message cp := cloneSubagentRun(run) _ = sm.runStore.AppendRun(cp) _ = sm.runStore.AppendEvent(SubagentRunEvent{ - RunID: cp.ID, - AgentID: cp.AgentID, - Type: strings.TrimSpace(eventType), - Status: cp.Status, - Message: strings.TrimSpace(message), - RetryCount: cp.RetryCount, - At: cp.Updated, + RunID: cp.ID, + AgentID: cp.AgentID, + Type: strings.TrimSpace(eventType), + Status: cp.Status, + FailureCode: strings.TrimSpace(cp.LastFailureCode), + Message: strings.TrimSpace(message), + RetryCount: cp.RetryCount, + At: cp.Updated, }) } diff --git a/pkg/tools/subagent_budget_test.go b/pkg/tools/subagent_budget_test.go new file mode 100644 index 0000000..6ce147d --- /dev/null +++ b/pkg/tools/subagent_budget_test.go @@ -0,0 +1,123 @@ +package tools + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/YspCoder/clawgo/pkg/providers" +) + +func TestSubagentRunPreservesRemainingIterationBudgetAcrossRetry(t *testing.T) { + t.Parallel() + + manager := NewSubagentManager(nil, t.TempDir(), nil) + attempts := 0 + manager.SetRunFunc(func(ctx context.Context, run *SubagentRun) (string, error) { + attempts++ + budget, ok := SubagentIterationBudget(ctx) + if !ok { + t.Fatal("expected subagent iteration budget in context") + } + if attempts == 1 { + if budget != 3 { + t.Fatalf("expected first attempt budget 3, got %d", budget) + } + RecordSubagentExecutionStats(ctx, SubagentExecutionStats{ + Iterations: 2, + Attempts: 1, + FailureCode: "stream_failed", + }) + return "", providers.NewProviderExecutionError("stream_failed", "stream failed", "stream", true, "test") + } + if budget != 1 { + t.Fatalf("expected retry 
to inherit remaining budget 1, got %d", budget) + } + RecordSubagentExecutionStats(ctx, SubagentExecutionStats{ + Iterations: 1, + Attempts: 1, + }) + return "done", nil + }) + + run, err := manager.SpawnRun(context.Background(), SubagentSpawnOptions{ + Task: "finish task", + AgentID: "tester", + MaxRetries: 1, + RetryBackoff: 1, + MaxToolIterations: 3, + }) + if err != nil { + t.Fatalf("spawn run: %v", err) + } + + waitCtx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + finalRun, _, err := manager.waitRun(waitCtx, run.ID) + if err != nil { + t.Fatalf("wait run: %v", err) + } + if finalRun.Status != RuntimeStatusCompleted { + t.Fatalf("expected completed run, got %+v", finalRun) + } + if finalRun.IterationCount != 3 { + t.Fatalf("expected 3 consumed iterations, got %d", finalRun.IterationCount) + } + if finalRun.AttemptCount != 2 { + t.Fatalf("expected 2 attempts, got %d", finalRun.AttemptCount) + } + events, err := manager.Events(run.ID, 10) + if err != nil { + t.Fatalf("events: %v", err) + } + if len(events) == 0 { + t.Fatal("expected persisted events") + } + foundFailure := false + for _, evt := range events { + if evt.Type == "attempt_failed" && evt.FailureCode == "stream_failed" { + foundFailure = true + } + } + if !foundFailure { + t.Fatalf("expected stream_failed event, got %#v", events) + } +} + +func TestSubagentRunStopsWhenIterationBudgetExhausted(t *testing.T) { + t.Parallel() + + manager := NewSubagentManager(nil, t.TempDir(), nil) + manager.SetRunFunc(func(ctx context.Context, run *SubagentRun) (string, error) { + RecordSubagentExecutionStats(ctx, SubagentExecutionStats{Iterations: 2, Attempts: 1}) + return "", errors.New("max tool iterations exceeded") + }) + + run, err := manager.SpawnRun(context.Background(), SubagentSpawnOptions{ + Task: "exhaust budget", + AgentID: "tester", + MaxRetries: 2, + RetryBackoff: 1, + MaxToolIterations: 2, + }) + if err != nil { + t.Fatalf("spawn run: %v", err) + } + + waitCtx, 
cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + finalRun, _, err := manager.waitRun(waitCtx, run.ID) + if err != nil { + t.Fatalf("wait run: %v", err) + } + if finalRun.Status != RuntimeStatusFailed { + t.Fatalf("expected failed run, got %+v", finalRun) + } + if finalRun.LastFailureCode != "retry_limit" { + t.Fatalf("expected retry_limit failure code, got %q", finalRun.LastFailureCode) + } + if finalRun.IterationCount != 2 { + t.Fatalf("expected consumed iterations to stay at 2, got %d", finalRun.IterationCount) + } +} diff --git a/pkg/tools/subagent_mailbox.go b/pkg/tools/subagent_mailbox.go index 9774eeb..21085f9 100644 --- a/pkg/tools/subagent_mailbox.go +++ b/pkg/tools/subagent_mailbox.go @@ -1,7 +1,6 @@ package tools import ( - "bufio" "encoding/json" "fmt" "os" @@ -10,6 +9,9 @@ import ( "strconv" "strings" "sync" + "time" + + "github.com/YspCoder/clawgo/pkg/jsonlog" ) type AgentThread struct { @@ -36,15 +38,39 @@ type AgentMessage struct { CreatedAt int64 `json:"created_at"` } +type mailboxThreadsMetaFile struct { + Version int `json:"version"` + LastOffset int64 `json:"last_offset,omitempty"` + ThreadSeq int `json:"thread_seq,omitempty"` + UpdatedAt int64 `json:"updated_at,omitempty"` + Threads []AgentThread `json:"threads,omitempty"` +} + +type mailboxMessagesMetaFile struct { + Version int `json:"version"` + LastOffset int64 `json:"last_offset,omitempty"` + MsgSeq int `json:"msg_seq,omitempty"` + UpdatedAt int64 `json:"updated_at,omitempty"` + Messages []AgentMessage `json:"messages,omitempty"` + ThreadMessages map[string][]string `json:"thread_messages,omitempty"` + QueuedByAgent map[string][]string `json:"queued_by_agent,omitempty"` + QueuedByThreadAgent map[string][]string `json:"queued_by_thread_agent,omitempty"` +} + type AgentMailboxStore struct { - dir string - threadsPath string - msgsPath string - mu sync.RWMutex - threads map[string]*AgentThread - messages map[string]*AgentMessage - msgSeq int - threadSeq int 
+ dir string + threadsPath string + msgsPath string + threadsMetaPath string + msgsMetaPath string + mu sync.RWMutex + threads map[string]*AgentThread + messages map[string]*AgentMessage + threadMessages map[string][]string + queuedByAgent map[string][]string + queuedByThreadAgent map[string][]string + msgSeq int + threadSeq int } func NewAgentMailboxStore(workspace string) *AgentMailboxStore { @@ -54,13 +80,18 @@ func NewAgentMailboxStore(workspace string) *AgentMailboxStore { } dir := filepath.Join(workspace, "agents", "runtime") s := &AgentMailboxStore{ - dir: dir, - threadsPath: filepath.Join(dir, "threads.jsonl"), - msgsPath: filepath.Join(dir, "agent_messages.jsonl"), - threads: map[string]*AgentThread{}, - messages: map[string]*AgentMessage{}, + dir: dir, + threadsPath: filepath.Join(dir, "threads.jsonl"), + msgsPath: filepath.Join(dir, "agent_messages.jsonl"), + threadsMetaPath: filepath.Join(dir, "threads.meta.json"), + msgsMetaPath: filepath.Join(dir, "agent_messages.meta.json"), + threads: map[string]*AgentThread{}, + messages: map[string]*AgentMessage{}, + threadMessages: map[string][]string{}, + queuedByAgent: map[string][]string{}, + queuedByThreadAgent: map[string][]string{}, } - _ = os.MkdirAll(dir, 0755) + _ = os.MkdirAll(dir, 0o755) _ = s.load() return s } @@ -68,75 +99,91 @@ func NewAgentMailboxStore(workspace string) *AgentMailboxStore { func (s *AgentMailboxStore) load() error { s.mu.Lock() defer s.mu.Unlock() - s.threads = map[string]*AgentThread{} - s.messages = map[string]*AgentMessage{} + s.resetLocked() if err := s.loadThreadsLocked(); err != nil { return err } - return s.scanMessagesLocked() + return s.loadMessagesLocked() +} + +func (s *AgentMailboxStore) resetLocked() { + s.threads = map[string]*AgentThread{} + s.messages = map[string]*AgentMessage{} + s.threadMessages = map[string][]string{} + s.queuedByAgent = map[string][]string{} + s.queuedByThreadAgent = map[string][]string{} + s.msgSeq = 0 + s.threadSeq = 0 } func (s 
*AgentMailboxStore) loadThreadsLocked() error { - f, err := os.Open(s.threadsPath) + size, err := jsonlog.FileSize(s.threadsPath) if err != nil { - if os.IsNotExist(err) { - return nil - } return err } - defer f.Close() - scanner := bufio.NewScanner(f) - buf := make([]byte, 0, 64*1024) - scanner.Buffer(buf, 2*1024*1024) - for scanner.Scan() { - line := strings.TrimSpace(scanner.Text()) - if line == "" { - continue + var meta mailboxThreadsMetaFile + if err := jsonlog.ReadJSON(s.threadsMetaPath, &meta); err == nil && meta.Version > 0 && meta.LastOffset == size { + for _, thread := range meta.Threads { + cp := thread + cp.Participants = append([]string(nil), thread.Participants...) + s.threads[cp.ThreadID] = &cp } + s.threadSeq = meta.ThreadSeq + return nil + } + if size == 0 { + return s.persistThreadsMetaLocked(size) + } + if err := jsonlog.Scan(s.threadsPath, func(line []byte) error { var thread AgentThread - if err := json.Unmarshal([]byte(line), &thread); err != nil { - continue + if err := json.Unmarshal(line, &thread); err != nil || strings.TrimSpace(thread.ThreadID) == "" { + return nil } cp := thread + cp.Participants = append([]string(nil), thread.Participants...) 
s.threads[thread.ThreadID] = &cp if n := parseThreadSequence(thread.ThreadID); n > s.threadSeq { s.threadSeq = n } - } - return scanner.Err() -} - -func (s *AgentMailboxStore) scanMessagesLocked() error { - f, err := os.Open(s.msgsPath) - if err != nil { - if os.IsNotExist(err) { - return nil - } + return nil + }); err != nil { return err } - defer f.Close() - scanner := bufio.NewScanner(f) - buf := make([]byte, 0, 64*1024) - scanner.Buffer(buf, 2*1024*1024) - for scanner.Scan() { - line := strings.TrimSpace(scanner.Text()) - if line == "" { - continue - } - var msg AgentMessage - if err := json.Unmarshal([]byte(line), &msg); err != nil { - continue - } - if n := parseMessageSequence(msg.MessageID); n > s.msgSeq { - s.msgSeq = n - } - cp := msg - s.messages[msg.MessageID] = &cp - if thread := s.threads[msg.ThreadID]; thread != nil && msg.CreatedAt > thread.UpdatedAt { - thread.UpdatedAt = msg.CreatedAt - } + return s.persistThreadsMetaLocked(size) +} + +func (s *AgentMailboxStore) loadMessagesLocked() error { + size, err := jsonlog.FileSize(s.msgsPath) + if err != nil { + return err } - return scanner.Err() + var meta mailboxMessagesMetaFile + if err := jsonlog.ReadJSON(s.msgsMetaPath, &meta); err == nil && meta.Version > 0 && meta.LastOffset == size { + for _, msg := range meta.Messages { + cp := msg + s.messages[msg.MessageID] = &cp + } + s.threadMessages = cloneStringSliceMap(meta.ThreadMessages) + s.queuedByAgent = cloneStringSliceMap(meta.QueuedByAgent) + s.queuedByThreadAgent = cloneStringSliceMap(meta.QueuedByThreadAgent) + s.msgSeq = meta.MsgSeq + s.reconcileThreadsFromMessagesLocked() + return nil + } + if size == 0 { + return s.persistMessagesMetaLocked(size) + } + if err := jsonlog.Scan(s.msgsPath, func(line []byte) error { + var msg AgentMessage + if err := json.Unmarshal(line, &msg); err != nil || strings.TrimSpace(msg.MessageID) == "" { + return nil + } + s.indexMessageLocked(msg) + return nil + }); err != nil { + return err + } + return 
s.persistMessagesMetaLocked(size) } func (s *AgentMailboxStore) EnsureThread(thread AgentThread) (AgentThread, error) { @@ -145,7 +192,7 @@ func (s *AgentMailboxStore) EnsureThread(thread AgentThread) (AgentThread, error } s.mu.Lock() defer s.mu.Unlock() - if err := os.MkdirAll(s.dir, 0755); err != nil { + if err := os.MkdirAll(s.dir, 0o755); err != nil { return AgentThread{}, err } if strings.TrimSpace(thread.ThreadID) == "" { @@ -165,20 +212,19 @@ func (s *AgentMailboxStore) EnsureThread(thread AgentThread) (AgentThread, error if thread.UpdatedAt <= 0 { thread.UpdatedAt = thread.CreatedAt } - data, err := json.Marshal(thread) + offset, err := jsonlog.AppendLine(s.threadsPath, thread) if err != nil { return AgentThread{}, err } - f, err := os.OpenFile(s.threadsPath, os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0644) - if err != nil { - return AgentThread{}, err - } - defer f.Close() - if _, err := f.Write(append(data, '\n')); err != nil { - return AgentThread{}, err - } cp := thread + cp.Participants = append([]string(nil), thread.Participants...) 
s.threads[thread.ThreadID] = &cp + if n := parseThreadSequence(thread.ThreadID); n > s.threadSeq { + s.threadSeq = n + } + if err := s.persistThreadsMetaLocked(offset); err != nil { + return AgentThread{}, err + } return thread, nil } @@ -188,7 +234,7 @@ func (s *AgentMailboxStore) AppendMessage(msg AgentMessage) (AgentMessage, error } s.mu.Lock() defer s.mu.Unlock() - if err := os.MkdirAll(s.dir, 0755); err != nil { + if err := os.MkdirAll(s.dir, 0o755); err != nil { return AgentMessage{}, err } if strings.TrimSpace(msg.MessageID) == "" { @@ -198,26 +244,17 @@ func (s *AgentMailboxStore) AppendMessage(msg AgentMessage) (AgentMessage, error if strings.TrimSpace(msg.Status) == "" { msg.Status = "queued" } - data, err := json.Marshal(msg) + offset, err := jsonlog.AppendLine(s.msgsPath, msg) if err != nil { return AgentMessage{}, err } - f, err := os.OpenFile(s.msgsPath, os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0644) - if err != nil { + s.indexMessageLocked(msg) + if err := s.persistMessagesMetaLocked(offset); err != nil { return AgentMessage{}, err } - defer f.Close() - if _, err := f.Write(append(data, '\n')); err != nil { + if err := s.persistThreadsMetaLocked(0); err != nil { return AgentMessage{}, err } - if thread := s.threads[msg.ThreadID]; thread != nil { - thread.UpdatedAt = msg.CreatedAt - participants := append([]string(nil), thread.Participants...) 
- participants = append(participants, msg.FromAgent, msg.ToAgent) - thread.Participants = normalizeStringList(participants) - } - cp := msg - s.messages[msg.MessageID] = &cp return msg, nil } @@ -240,30 +277,21 @@ func (s *AgentMailboxStore) MessagesByThread(threadID string, limit int) ([]Agen if s == nil { return nil, nil } - return s.currentMessages(func(msg AgentMessage) bool { - return msg.ThreadID == strings.TrimSpace(threadID) - }, limit), nil + return s.currentIndexedMessages(s.threadMessages[strings.TrimSpace(threadID)], limit), nil } func (s *AgentMailboxStore) Inbox(agentID string, limit int) ([]AgentMessage, error) { if s == nil { return nil, nil } - agentID = strings.TrimSpace(agentID) - return s.currentMessages(func(msg AgentMessage) bool { - return msg.ToAgent == agentID && strings.EqualFold(strings.TrimSpace(msg.Status), "queued") - }, limit), nil + return s.currentIndexedMessages(s.queuedByAgent[strings.TrimSpace(agentID)], limit), nil } func (s *AgentMailboxStore) ThreadInbox(threadID, agentID string, limit int) ([]AgentMessage, error) { if s == nil { return nil, nil } - threadID = strings.TrimSpace(threadID) - agentID = strings.TrimSpace(agentID) - return s.currentMessages(func(msg AgentMessage) bool { - return msg.ThreadID == threadID && msg.ToAgent == agentID && strings.EqualFold(strings.TrimSpace(msg.Status), "queued") - }, limit), nil + return s.currentIndexedMessages(s.queuedByThreadAgent[mailboxThreadAgentKey(threadID, agentID)], limit), nil } func (s *AgentMailboxStore) Message(messageID string) (*AgentMessage, bool) { @@ -305,16 +333,16 @@ func (s *AgentMailboxStore) UpdateMessageStatus(messageID, status string, at int return &msg, nil } -func (s *AgentMailboxStore) currentMessages(match func(AgentMessage) bool, limit int) []AgentMessage { +func (s *AgentMailboxStore) currentIndexedMessages(ids []string, limit int) []AgentMessage { s.mu.RLock() defer s.mu.RUnlock() - var out []AgentMessage - for _, item := range s.messages { - msg := *item 
- if match != nil && !match(msg) { + out := make([]AgentMessage, 0, len(ids)) + for _, id := range ids { + msg := s.messages[id] + if msg == nil { continue } - out = append(out, msg) + out = append(out, *msg) } sort.Slice(out, func(i, j int) bool { if out[i].CreatedAt != out[j].CreatedAt { @@ -328,6 +356,136 @@ func (s *AgentMailboxStore) currentMessages(match func(AgentMessage) bool, limit return out } +func (s *AgentMailboxStore) indexMessageLocked(msg AgentMessage) { + msg.MessageID = strings.TrimSpace(msg.MessageID) + if msg.MessageID == "" { + return + } + if n := parseMessageSequence(msg.MessageID); n > s.msgSeq { + s.msgSeq = n + } + if existing := s.messages[msg.MessageID]; existing != nil { + s.removeQueuedIndexesLocked(*existing) + } + cp := msg + s.messages[msg.MessageID] = &cp + s.threadMessages[msg.ThreadID] = appendUniqueString(s.threadMessages[msg.ThreadID], msg.MessageID) + if strings.EqualFold(strings.TrimSpace(msg.Status), "queued") { + s.queuedByAgent[msg.ToAgent] = appendUniqueString(s.queuedByAgent[msg.ToAgent], msg.MessageID) + s.queuedByThreadAgent[mailboxThreadAgentKey(msg.ThreadID, msg.ToAgent)] = appendUniqueString(s.queuedByThreadAgent[mailboxThreadAgentKey(msg.ThreadID, msg.ToAgent)], msg.MessageID) + } + thread := s.ensureThreadStateLocked(msg.ThreadID) + if thread != nil { + if msg.CreatedAt > thread.UpdatedAt { + thread.UpdatedAt = msg.CreatedAt + } + participants := append([]string(nil), thread.Participants...) 
+ participants = append(participants, msg.FromAgent, msg.ToAgent) + thread.Participants = normalizeStringList(participants) + } +} + +func (s *AgentMailboxStore) ensureThreadStateLocked(threadID string) *AgentThread { + threadID = strings.TrimSpace(threadID) + if threadID == "" { + return nil + } + thread := s.threads[threadID] + if thread != nil { + return thread + } + thread = &AgentThread{ + ThreadID: threadID, + Status: "open", + CreatedAt: 1, + UpdatedAt: 1, + Participants: nil, + } + s.threads[threadID] = thread + if n := parseThreadSequence(threadID); n > s.threadSeq { + s.threadSeq = n + } + return thread +} + +func (s *AgentMailboxStore) removeQueuedIndexesLocked(msg AgentMessage) { + if !strings.EqualFold(strings.TrimSpace(msg.Status), "queued") { + return + } + s.queuedByAgent[msg.ToAgent] = removeStringValue(s.queuedByAgent[msg.ToAgent], msg.MessageID) + s.queuedByThreadAgent[mailboxThreadAgentKey(msg.ThreadID, msg.ToAgent)] = removeStringValue(s.queuedByThreadAgent[mailboxThreadAgentKey(msg.ThreadID, msg.ToAgent)], msg.MessageID) +} + +func (s *AgentMailboxStore) reconcileThreadsFromMessagesLocked() { + for _, msg := range s.messages { + s.ensureThreadStateLocked(msg.ThreadID) + } + for _, msg := range s.messages { + thread := s.threads[msg.ThreadID] + if thread == nil { + continue + } + if msg.CreatedAt > thread.UpdatedAt { + thread.UpdatedAt = msg.CreatedAt + } + participants := append([]string(nil), thread.Participants...) 
+ participants = append(participants, msg.FromAgent, msg.ToAgent) + thread.Participants = normalizeStringList(participants) + } +} + +func (s *AgentMailboxStore) persistThreadsMetaLocked(offset int64) error { + if offset <= 0 { + size, err := jsonlog.FileSize(s.threadsPath) + if err != nil { + return err + } + offset = size + } + meta := mailboxThreadsMetaFile{ + Version: 1, + LastOffset: offset, + ThreadSeq: s.threadSeq, + UpdatedAt: time.Now().UnixMilli(), + Threads: make([]AgentThread, 0, len(s.threads)), + } + for _, thread := range s.threads { + cp := *thread + cp.Participants = append([]string(nil), thread.Participants...) + meta.Threads = append(meta.Threads, cp) + } + sort.Slice(meta.Threads, func(i, j int) bool { + if meta.Threads[i].UpdatedAt != meta.Threads[j].UpdatedAt { + return meta.Threads[i].UpdatedAt > meta.Threads[j].UpdatedAt + } + return meta.Threads[i].ThreadID < meta.Threads[j].ThreadID + }) + return jsonlog.WriteJSON(s.threadsMetaPath, meta) +} + +func (s *AgentMailboxStore) persistMessagesMetaLocked(offset int64) error { + meta := mailboxMessagesMetaFile{ + Version: 1, + LastOffset: offset, + MsgSeq: s.msgSeq, + UpdatedAt: time.Now().UnixMilli(), + Messages: make([]AgentMessage, 0, len(s.messages)), + ThreadMessages: cloneStringSliceMap(s.threadMessages), + QueuedByAgent: cloneStringSliceMap(s.queuedByAgent), + QueuedByThreadAgent: cloneStringSliceMap(s.queuedByThreadAgent), + } + for _, msg := range s.messages { + meta.Messages = append(meta.Messages, *msg) + } + sort.Slice(meta.Messages, func(i, j int) bool { + if meta.Messages[i].CreatedAt != meta.Messages[j].CreatedAt { + return meta.Messages[i].CreatedAt < meta.Messages[j].CreatedAt + } + return meta.Messages[i].MessageID < meta.Messages[j].MessageID + }) + return jsonlog.WriteJSON(s.msgsMetaPath, meta) +} + func parseThreadSequence(threadID string) int { threadID = strings.TrimSpace(threadID) if !strings.HasPrefix(threadID, "thread-") { @@ -386,3 +544,45 @@ func 
parseMessageSequence(messageID string) int { n, _ := strconv.Atoi(strings.TrimPrefix(messageID, "msg-")) return n } + +func mailboxThreadAgentKey(threadID, agentID string) string { + return strings.TrimSpace(threadID) + "::" + strings.TrimSpace(agentID) +} + +func cloneStringSliceMap(src map[string][]string) map[string][]string { + if len(src) == 0 { + return map[string][]string{} + } + out := make(map[string][]string, len(src)) + for key, values := range src { + out[key] = append([]string(nil), values...) + } + return out +} + +func appendUniqueString(values []string, value string) []string { + value = strings.TrimSpace(value) + if value == "" { + return values + } + for _, existing := range values { + if existing == value { + return values + } + } + return append(values, value) +} + +func removeStringValue(values []string, value string) []string { + if len(values) == 0 { + return values + } + out := values[:0] + for _, existing := range values { + if existing == value { + continue + } + out = append(out, existing) + } + return append([]string(nil), out...) 
+} diff --git a/pkg/tools/subagent_mailbox_test.go b/pkg/tools/subagent_mailbox_test.go new file mode 100644 index 0000000..0645cfb --- /dev/null +++ b/pkg/tools/subagent_mailbox_test.go @@ -0,0 +1,81 @@ +package tools + +import ( + "os" + "path/filepath" + "testing" +) + +func TestAgentMailboxStoreReloadsInboxIndexes(t *testing.T) { + t.Parallel() + + workspace := t.TempDir() + store := NewAgentMailboxStore(workspace) + + thread, err := store.EnsureThread(AgentThread{ + Owner: "planner", + Topic: "handoff", + CreatedAt: 10, + UpdatedAt: 10, + }) + if err != nil { + t.Fatalf("ensure thread: %v", err) + } + msg, err := store.AppendMessage(AgentMessage{ + ThreadID: thread.ThreadID, + FromAgent: "planner", + ToAgent: "worker-a", + Type: "task", + Content: "check deploy logs", + RequiresReply: true, + Status: "queued", + CreatedAt: 20, + }) + if err != nil { + t.Fatalf("append message: %v", err) + } + + inbox, err := store.Inbox("worker-a", 10) + if err != nil { + t.Fatalf("inbox: %v", err) + } + if len(inbox) != 1 || inbox[0].MessageID != msg.MessageID { + t.Fatalf("unexpected inbox: %#v", inbox) + } + + reloaded := NewAgentMailboxStore(workspace) + threadInbox, err := reloaded.ThreadInbox(thread.ThreadID, "worker-a", 10) + if err != nil { + t.Fatalf("thread inbox: %v", err) + } + if len(threadInbox) != 1 || threadInbox[0].Content != "check deploy logs" { + t.Fatalf("unexpected thread inbox: %#v", threadInbox) + } + + if _, err := reloaded.UpdateMessageStatus(msg.MessageID, "processed", 30); err != nil { + t.Fatalf("update status: %v", err) + } + inbox, err = reloaded.Inbox("worker-a", 10) + if err != nil { + t.Fatalf("inbox after status update: %v", err) + } + if len(inbox) != 0 { + t.Fatalf("expected empty inbox after status update, got %#v", inbox) + } + + runtimeDir := filepath.Join(workspace, "agents", "runtime") + if err := os.Remove(filepath.Join(runtimeDir, "threads.meta.json")); err != nil { + t.Fatalf("remove thread meta: %v", err) + } + if err := 
os.Remove(filepath.Join(runtimeDir, "agent_messages.meta.json")); err != nil { + t.Fatalf("remove messages meta: %v", err) + } + rebuilt := NewAgentMailboxStore(workspace) + messages, err := rebuilt.MessagesByThread(thread.ThreadID, 10) + if err != nil { + t.Fatalf("messages by thread: %v", err) + } + if len(messages) != 1 || messages[0].Status != "processed" { + t.Fatalf("unexpected rebuilt messages: %#v", messages) + } +} diff --git a/pkg/tools/subagent_profile.go b/pkg/tools/subagent_profile.go index 765c404..541399e 100644 --- a/pkg/tools/subagent_profile.go +++ b/pkg/tools/subagent_profile.go @@ -16,24 +16,25 @@ import ( ) type SubagentProfile struct { - AgentID string `json:"agent_id"` - Name string `json:"name"` - Transport string `json:"transport,omitempty"` - ParentAgentID string `json:"parent_agent_id,omitempty"` - NotifyMainPolicy string `json:"notify_main_policy,omitempty"` - Role string `json:"role,omitempty"` - SystemPromptFile string `json:"system_prompt_file,omitempty"` - ToolAllowlist []string `json:"tool_allowlist,omitempty"` - MemoryNamespace string `json:"memory_namespace,omitempty"` - MaxRetries int `json:"max_retries,omitempty"` - RetryBackoff int `json:"retry_backoff_ms,omitempty"` - TimeoutSec int `json:"timeout_sec,omitempty"` - MaxTaskChars int `json:"max_task_chars,omitempty"` - MaxResultChars int `json:"max_result_chars,omitempty"` - Status string `json:"status"` - CreatedAt int64 `json:"created_at"` - UpdatedAt int64 `json:"updated_at"` - ManagedBy string `json:"managed_by,omitempty"` + AgentID string `json:"agent_id"` + Name string `json:"name"` + Transport string `json:"transport,omitempty"` + ParentAgentID string `json:"parent_agent_id,omitempty"` + NotifyMainPolicy string `json:"notify_main_policy,omitempty"` + Role string `json:"role,omitempty"` + SystemPromptFile string `json:"system_prompt_file,omitempty"` + ToolAllowlist []string `json:"tool_allowlist,omitempty"` + MemoryNamespace string `json:"memory_namespace,omitempty"` + 
MaxRetries int `json:"max_retries,omitempty"` + RetryBackoff int `json:"retry_backoff_ms,omitempty"` + TimeoutSec int `json:"timeout_sec,omitempty"` + MaxToolIterations int `json:"max_tool_iterations,omitempty"` + MaxTaskChars int `json:"max_task_chars,omitempty"` + MaxResultChars int `json:"max_result_chars,omitempty"` + Status string `json:"status"` + CreatedAt int64 `json:"created_at"` + UpdatedAt int64 `json:"updated_at"` + ManagedBy string `json:"managed_by,omitempty"` } type SubagentProfileStore struct { @@ -192,6 +193,7 @@ func normalizeSubagentProfile(in SubagentProfile) SubagentProfile { p.MaxRetries = clampInt(p.MaxRetries, 0, 8) p.RetryBackoff = clampInt(p.RetryBackoff, 500, 120000) p.TimeoutSec = clampInt(p.TimeoutSec, 0, 3600) + p.MaxToolIterations = clampInt(p.MaxToolIterations, 0, 200) p.MaxTaskChars = clampInt(p.MaxTaskChars, 0, 400000) p.MaxResultChars = clampInt(p.MaxResultChars, 0, 400000) return p @@ -336,22 +338,23 @@ func profileFromConfig(agentID string, subcfg config.SubagentConfig) SubagentPro status = "disabled" } return normalizeSubagentProfile(SubagentProfile{ - AgentID: agentID, - Name: strings.TrimSpace(subcfg.DisplayName), - Transport: strings.TrimSpace(subcfg.Transport), - ParentAgentID: strings.TrimSpace(subcfg.ParentAgentID), - NotifyMainPolicy: strings.TrimSpace(subcfg.NotifyMainPolicy), - Role: strings.TrimSpace(subcfg.Role), - SystemPromptFile: strings.TrimSpace(subcfg.SystemPromptFile), - ToolAllowlist: append([]string(nil), subcfg.Tools.Allowlist...), - MemoryNamespace: strings.TrimSpace(subcfg.MemoryNamespace), - MaxRetries: subcfg.Runtime.MaxRetries, - RetryBackoff: subcfg.Runtime.RetryBackoffMs, - TimeoutSec: subcfg.Runtime.TimeoutSec, - MaxTaskChars: subcfg.Runtime.MaxTaskChars, - MaxResultChars: subcfg.Runtime.MaxResultChars, - Status: status, - ManagedBy: "config.json", + AgentID: agentID, + Name: strings.TrimSpace(subcfg.DisplayName), + Transport: strings.TrimSpace(subcfg.Transport), + ParentAgentID: 
strings.TrimSpace(subcfg.ParentAgentID), + NotifyMainPolicy: strings.TrimSpace(subcfg.NotifyMainPolicy), + Role: strings.TrimSpace(subcfg.Role), + SystemPromptFile: strings.TrimSpace(subcfg.SystemPromptFile), + ToolAllowlist: append([]string(nil), subcfg.Tools.Allowlist...), + MemoryNamespace: strings.TrimSpace(subcfg.MemoryNamespace), + MaxRetries: subcfg.Runtime.MaxRetries, + RetryBackoff: subcfg.Runtime.RetryBackoffMs, + TimeoutSec: subcfg.Runtime.TimeoutSec, + MaxToolIterations: subcfg.Runtime.MaxToolIterations, + MaxTaskChars: subcfg.Runtime.MaxTaskChars, + MaxResultChars: subcfg.Runtime.MaxResultChars, + Status: status, + ManagedBy: "config.json", }) } @@ -389,11 +392,12 @@ func (t *SubagentProfileTool) Parameters() map[string]interface{} { "description": "Tool allowlist entries. Supports tool names, '*'/'all', and grouped tokens like 'group:files_read'.", "items": map[string]interface{}{"type": "string"}, }, - "max_retries": map[string]interface{}{"type": "integer", "description": "Retry limit for subagent task execution."}, - "retry_backoff_ms": map[string]interface{}{"type": "integer", "description": "Backoff between retries in milliseconds."}, - "timeout_sec": map[string]interface{}{"type": "integer", "description": "Per-attempt timeout in seconds."}, - "max_task_chars": map[string]interface{}{"type": "integer", "description": "Task input size quota (characters)."}, - "max_result_chars": map[string]interface{}{"type": "integer", "description": "Result output size quota (characters)."}, + "max_retries": map[string]interface{}{"type": "integer", "description": "Retry limit for subagent task execution."}, + "retry_backoff_ms": map[string]interface{}{"type": "integer", "description": "Backoff between retries in milliseconds."}, + "timeout_sec": map[string]interface{}{"type": "integer", "description": "Per-attempt timeout in seconds."}, + "max_tool_iterations": map[string]interface{}{"type": "integer", "description": "Independent tool-calling iteration budget 
for this subagent."}, + "max_task_chars": map[string]interface{}{"type": "integer", "description": "Task input size quota (characters)."}, + "max_result_chars": map[string]interface{}{"type": "integer", "description": "Result output size quota (characters)."}, }, "required": []string{"action"}, } @@ -445,19 +449,20 @@ func (t *SubagentProfileTool) Execute(ctx context.Context, args map[string]inter return "subagent profile already exists", nil } p := SubagentProfile{ - AgentID: agentID, - Name: stringArg(args, "name"), - NotifyMainPolicy: stringArg(args, "notify_main_policy"), - Role: stringArg(args, "role"), - SystemPromptFile: stringArg(args, "system_prompt_file"), - MemoryNamespace: stringArg(args, "memory_namespace"), - Status: stringArg(args, "status"), - ToolAllowlist: parseStringList(args["tool_allowlist"]), - MaxRetries: profileIntArg(args, "max_retries"), - RetryBackoff: profileIntArg(args, "retry_backoff_ms"), - TimeoutSec: profileIntArg(args, "timeout_sec"), - MaxTaskChars: profileIntArg(args, "max_task_chars"), - MaxResultChars: profileIntArg(args, "max_result_chars"), + AgentID: agentID, + Name: stringArg(args, "name"), + NotifyMainPolicy: stringArg(args, "notify_main_policy"), + Role: stringArg(args, "role"), + SystemPromptFile: stringArg(args, "system_prompt_file"), + MemoryNamespace: stringArg(args, "memory_namespace"), + Status: stringArg(args, "status"), + ToolAllowlist: parseStringList(args["tool_allowlist"]), + MaxRetries: profileIntArg(args, "max_retries"), + RetryBackoff: profileIntArg(args, "retry_backoff_ms"), + TimeoutSec: profileIntArg(args, "timeout_sec"), + MaxToolIterations: profileIntArg(args, "max_tool_iterations"), + MaxTaskChars: profileIntArg(args, "max_task_chars"), + MaxResultChars: profileIntArg(args, "max_result_chars"), } saved, err := t.store.Upsert(p) if err != nil { @@ -506,6 +511,9 @@ func (t *SubagentProfileTool) Execute(ctx context.Context, args map[string]inter if _, ok := args["timeout_sec"]; ok { next.TimeoutSec = 
profileIntArg(args, "timeout_sec") } + if _, ok := args["max_tool_iterations"]; ok { + next.MaxToolIterations = profileIntArg(args, "max_tool_iterations") + } if _, ok := args["max_task_chars"]; ok { next.MaxTaskChars = profileIntArg(args, "max_task_chars") } diff --git a/pkg/tools/subagent_runtime_context.go b/pkg/tools/subagent_runtime_context.go new file mode 100644 index 0000000..08b795b --- /dev/null +++ b/pkg/tools/subagent_runtime_context.go @@ -0,0 +1,51 @@ +package tools + +import "context" + +type SubagentExecutionStats struct { + Iterations int + Attempts int + Restarts int + FailureCode string +} + +type subagentExecutionStatsKey struct{} +type subagentIterationBudgetKey struct{} + +func WithSubagentExecutionStats(ctx context.Context, stats *SubagentExecutionStats) context.Context { + if ctx == nil { + ctx = context.Background() + } + return context.WithValue(ctx, subagentExecutionStatsKey{}, stats) +} + +func RecordSubagentExecutionStats(ctx context.Context, delta SubagentExecutionStats) { + if ctx == nil { + return + } + stats, _ := ctx.Value(subagentExecutionStatsKey{}).(*SubagentExecutionStats) + if stats == nil { + return + } + stats.Iterations += delta.Iterations + stats.Attempts += delta.Attempts + stats.Restarts += delta.Restarts + if delta.FailureCode != "" { + stats.FailureCode = delta.FailureCode + } +} + +func WithSubagentIterationBudget(ctx context.Context, budget int) context.Context { + if ctx == nil { + ctx = context.Background() + } + return context.WithValue(ctx, subagentIterationBudgetKey{}, budget) +} + +func SubagentIterationBudget(ctx context.Context) (int, bool) { + if ctx == nil { + return 0, false + } + budget, ok := ctx.Value(subagentIterationBudgetKey{}).(int) + return budget, ok && budget > 0 +} diff --git a/pkg/tools/subagent_store.go b/pkg/tools/subagent_store.go index 6005e12..1f49452 100644 --- a/pkg/tools/subagent_store.go +++ b/pkg/tools/subagent_store.go @@ -1,7 +1,6 @@ package tools import ( - "bufio" "encoding/json" 
"os" "path/filepath" @@ -9,24 +8,47 @@ import ( "strconv" "strings" "sync" + "time" + + "github.com/YspCoder/clawgo/pkg/jsonlog" ) type SubagentRunEvent struct { - RunID string `json:"run_id"` - AgentID string `json:"agent_id,omitempty"` - Type string `json:"type"` - Status string `json:"status,omitempty"` - Message string `json:"message,omitempty"` - RetryCount int `json:"retry_count,omitempty"` - At int64 `json:"ts"` + RunID string `json:"run_id"` + AgentID string `json:"agent_id,omitempty"` + Type string `json:"type"` + Status string `json:"status,omitempty"` + FailureCode string `json:"failure_code,omitempty"` + Message string `json:"message,omitempty"` + RetryCount int `json:"retry_count,omitempty"` + At int64 `json:"ts"` +} + +type subagentRunsMetaFile struct { + Version int `json:"version"` + LastOffset int64 `json:"last_offset,omitempty"` + NextIDSeed int `json:"next_id_seed,omitempty"` + UpdatedAt int64 `json:"updated_at,omitempty"` + Runs []SubagentRun `json:"runs,omitempty"` +} + +type subagentEventsMetaFile struct { + Version int `json:"version"` + LastOffset int64 `json:"last_offset,omitempty"` + UpdatedAt int64 `json:"updated_at,omitempty"` + EventsByRun map[string][]SubagentRunEvent `json:"events_by_run,omitempty"` } type SubagentRunStore struct { - dir string - runsPath string - eventsPath string - mu sync.RWMutex - runs map[string]*SubagentRun + dir string + runsPath string + eventsPath string + runsMetaPath string + eventsMetaPath string + mu sync.RWMutex + runs map[string]*SubagentRun + events map[string][]SubagentRunEvent + nextIDSeed int } func NewSubagentRunStore(workspace string) *SubagentRunStore { @@ -36,12 +58,16 @@ func NewSubagentRunStore(workspace string) *SubagentRunStore { } dir := filepath.Join(workspace, "agents", "runtime") store := &SubagentRunStore{ - dir: dir, - runsPath: filepath.Join(dir, "subagent_runs.jsonl"), - eventsPath: filepath.Join(dir, "subagent_events.jsonl"), - runs: map[string]*SubagentRun{}, + dir: dir, + 
runsPath: filepath.Join(dir, "subagent_runs.jsonl"), + eventsPath: filepath.Join(dir, "subagent_events.jsonl"), + runsMetaPath: filepath.Join(dir, "subagent_runs.meta.json"), + eventsMetaPath: filepath.Join(dir, "subagent_events.meta.json"), + runs: map[string]*SubagentRun{}, + events: map[string][]SubagentRunEvent{}, + nextIDSeed: 1, } - _ = os.MkdirAll(dir, 0755) + _ = os.MkdirAll(dir, 0o755) _ = store.load() return store } @@ -51,48 +77,80 @@ func (s *SubagentRunStore) load() error { defer s.mu.Unlock() s.runs = map[string]*SubagentRun{} - f, err := os.Open(s.runsPath) - if err != nil { - if os.IsNotExist(err) { - return nil - } + s.events = map[string][]SubagentRunEvent{} + s.nextIDSeed = 1 + + if err := s.loadRunsLocked(); err != nil { return err } - defer f.Close() + return s.loadEventsLocked() +} - scanner := bufio.NewScanner(f) - buf := make([]byte, 0, 64*1024) - scanner.Buffer(buf, 2*1024*1024) - for scanner.Scan() { - line := strings.TrimSpace(scanner.Text()) - if line == "" { - continue - } - var record RunRecord - if err := json.Unmarshal([]byte(line), &record); err == nil && strings.TrimSpace(record.ID) != "" { - run := &SubagentRun{ - ID: record.ID, - Task: record.Input, - AgentID: record.AgentID, - ThreadID: record.ThreadID, - CorrelationID: record.CorrelationID, - ParentRunID: record.ParentRunID, - Status: record.Status, - Result: record.Output, - Created: record.CreatedAt, - Updated: record.UpdatedAt, - } - s.runs[run.ID] = run - continue - } - var run SubagentRun - if err := json.Unmarshal([]byte(line), &run); err != nil { - continue - } - cp := cloneSubagentRun(&run) - s.runs[run.ID] = cp +func (s *SubagentRunStore) loadRunsLocked() error { + size, err := jsonlog.FileSize(s.runsPath) + if err != nil { + return err } - return scanner.Err() + var meta subagentRunsMetaFile + if err := jsonlog.ReadJSON(s.runsMetaPath, &meta); err == nil && meta.Version > 0 && meta.LastOffset == size { + for _, run := range meta.Runs { + cp := cloneSubagentRun(&run) + 
s.runs[cp.ID] = cp + } + s.nextIDSeed = meta.NextIDSeed + if s.nextIDSeed <= 0 { + s.nextIDSeed = deriveNextRunSeed(s.runs) + } + return nil + } + + if size == 0 { + return s.persistRunsMetaLocked(size) + } + if err := jsonlog.Scan(s.runsPath, func(line []byte) error { + run, ok := decodeSubagentRunLine(line) + if !ok { + return nil + } + s.runs[run.ID] = run + return nil + }); err != nil { + return err + } + s.nextIDSeed = deriveNextRunSeed(s.runs) + return s.persistRunsMetaLocked(size) +} + +func (s *SubagentRunStore) loadEventsLocked() error { + size, err := jsonlog.FileSize(s.eventsPath) + if err != nil { + return err + } + var meta subagentEventsMetaFile + if err := jsonlog.ReadJSON(s.eventsMetaPath, &meta); err == nil && meta.Version > 0 && meta.LastOffset == size { + s.events = cloneEventsByRun(meta.EventsByRun) + return nil + } + if size == 0 { + return s.persistEventsMetaLocked(size) + } + eventsByRun := map[string][]SubagentRunEvent{} + if err := jsonlog.Scan(s.eventsPath, func(line []byte) error { + evt, ok := decodeSubagentEventLine(line) + if !ok { + return nil + } + eventsByRun[evt.RunID] = append(eventsByRun[evt.RunID], evt) + return nil + }); err != nil { + return err + } + for runID, events := range eventsByRun { + sort.Slice(events, func(i, j int) bool { return events[i].At < events[j].At }) + eventsByRun[runID] = events + } + s.events = eventsByRun + return s.persistEventsMetaLocked(size) } func (s *SubagentRunStore) AppendRun(run *SubagentRun) error { @@ -100,26 +158,22 @@ func (s *SubagentRunStore) AppendRun(run *SubagentRun) error { return nil } cp := cloneSubagentRun(run) - data, err := json.Marshal(runToRunRecord(cp)) - if err != nil { - return err - } + record := runToRunRecord(cp) s.mu.Lock() defer s.mu.Unlock() - if err := os.MkdirAll(s.dir, 0755); err != nil { + if err := os.MkdirAll(s.dir, 0o755); err != nil { return err } - f, err := os.OpenFile(s.runsPath, os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0644) + offset, err := 
jsonlog.AppendLine(s.runsPath, record) if err != nil { return err } - defer f.Close() - if _, err := f.Write(append(data, '\n')); err != nil { - return err - } s.runs[cp.ID] = cp - return nil + if next := parseSubagentSequence(cp.ID) + 1; next > s.nextIDSeed { + s.nextIDSeed = next + } + return s.persistRunsMetaLocked(offset) } func (s *SubagentRunStore) AppendEvent(evt SubagentRunEvent) error { @@ -127,32 +181,30 @@ func (s *SubagentRunStore) AppendEvent(evt SubagentRunEvent) error { return nil } record := EventRecord{ - ID: EventRecordID(evt.RunID, evt.Type, evt.At), - RunID: evt.RunID, - RequestID: evt.RunID, - AgentID: evt.AgentID, - Type: evt.Type, - Status: evt.Status, - Message: evt.Message, - RetryCount: evt.RetryCount, - At: evt.At, - } - data, err := json.Marshal(record) - if err != nil { - return err + ID: EventRecordID(evt.RunID, evt.Type, evt.At), + RunID: evt.RunID, + RequestID: evt.RunID, + AgentID: evt.AgentID, + Type: evt.Type, + Status: evt.Status, + FailureCode: evt.FailureCode, + Message: evt.Message, + RetryCount: evt.RetryCount, + At: evt.At, } + s.mu.Lock() defer s.mu.Unlock() - if err := os.MkdirAll(s.dir, 0755); err != nil { + if err := os.MkdirAll(s.dir, 0o755); err != nil { return err } - f, err := os.OpenFile(s.eventsPath, os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0644) + offset, err := jsonlog.AppendLine(s.eventsPath, record) if err != nil { return err } - defer f.Close() - _, err = f.Write(append(data, '\n')) - return err + s.events[evt.RunID] = append(s.events[evt.RunID], evt) + sort.Slice(s.events[evt.RunID], func(i, j int) bool { return s.events[evt.RunID][i].At < s.events[evt.RunID][j].At }) + return s.persistEventsMetaLocked(offset) } func (s *SubagentRunStore) Get(runID string) (*SubagentRun, bool) { @@ -191,54 +243,14 @@ func (s *SubagentRunStore) Events(runID string, limit int) ([]SubagentRunEvent, if s == nil { return nil, nil } - f, err := os.Open(s.eventsPath) - if err != nil { - if os.IsNotExist(err) { - return nil, nil - } - 
return nil, err + s.mu.RLock() + defer s.mu.RUnlock() + items := append([]SubagentRunEvent(nil), s.events[strings.TrimSpace(runID)]...) + sort.Slice(items, func(i, j int) bool { return items[i].At < items[j].At }) + if limit > 0 && len(items) > limit { + items = items[len(items)-limit:] } - defer f.Close() - - runID = strings.TrimSpace(runID) - events := make([]SubagentRunEvent, 0) - scanner := bufio.NewScanner(f) - buf := make([]byte, 0, 64*1024) - scanner.Buffer(buf, 2*1024*1024) - for scanner.Scan() { - line := strings.TrimSpace(scanner.Text()) - if line == "" { - continue - } - var evt SubagentRunEvent - if err := json.Unmarshal([]byte(line), &evt); err != nil { - var record EventRecord - if err := json.Unmarshal([]byte(line), &record); err != nil { - continue - } - evt = SubagentRunEvent{ - RunID: record.RunID, - AgentID: record.AgentID, - Type: record.Type, - Status: record.Status, - Message: record.Message, - RetryCount: record.RetryCount, - At: record.At, - } - } - if evt.RunID != runID { - continue - } - events = append(events, evt) - } - if err := scanner.Err(); err != nil { - return nil, err - } - sort.Slice(events, func(i, j int) bool { return events[i].At < events[j].At }) - if limit > 0 && len(events) > limit { - events = events[len(events)-limit:] - } - return events, nil + return items, nil } func (s *SubagentRunStore) NextIDSeed() int { @@ -247,8 +259,46 @@ func (s *SubagentRunStore) NextIDSeed() int { } s.mu.RLock() defer s.mu.RUnlock() + if s.nextIDSeed <= 0 { + return 1 + } + return s.nextIDSeed +} + +func (s *SubagentRunStore) persistRunsMetaLocked(offset int64) error { + meta := subagentRunsMetaFile{ + Version: 1, + LastOffset: offset, + NextIDSeed: maxRunSeed(deriveNextRunSeed(s.runs), s.nextIDSeed, 1), + UpdatedAt: time.Now().UnixMilli(), + Runs: make([]SubagentRun, 0, len(s.runs)), + } + for _, run := range s.runs { + meta.Runs = append(meta.Runs, *cloneSubagentRun(run)) + } + sort.Slice(meta.Runs, func(i, j int) bool { + if 
meta.Runs[i].Created != meta.Runs[j].Created { + return meta.Runs[i].Created > meta.Runs[j].Created + } + return meta.Runs[i].ID > meta.Runs[j].ID + }) + s.nextIDSeed = meta.NextIDSeed + return jsonlog.WriteJSON(s.runsMetaPath, meta) +} + +func (s *SubagentRunStore) persistEventsMetaLocked(offset int64) error { + meta := subagentEventsMetaFile{ + Version: 1, + LastOffset: offset, + UpdatedAt: time.Now().UnixMilli(), + EventsByRun: cloneEventsByRun(s.events), + } + return jsonlog.WriteJSON(s.eventsMetaPath, meta) +} + +func deriveNextRunSeed(runs map[string]*SubagentRun) int { maxSeq := 0 - for runID := range s.runs { + for runID := range runs { if n := parseSubagentSequence(runID); n > maxSeq { maxSeq = n } @@ -268,6 +318,71 @@ func parseSubagentSequence(runID string) int { return n } +func decodeSubagentRunLine(line []byte) (*SubagentRun, bool) { + var record RunRecord + if err := json.Unmarshal(line, &record); err == nil && strings.TrimSpace(record.ID) != "" { + return &SubagentRun{ + ID: record.ID, + Task: record.Input, + AgentID: record.AgentID, + ThreadID: record.ThreadID, + CorrelationID: record.CorrelationID, + ParentRunID: record.ParentRunID, + Status: record.Status, + Result: record.Output, + Created: record.CreatedAt, + Updated: record.UpdatedAt, + }, true + } + var run SubagentRun + if err := json.Unmarshal(line, &run); err != nil || strings.TrimSpace(run.ID) == "" { + return nil, false + } + return cloneSubagentRun(&run), true +} + +func decodeSubagentEventLine(line []byte) (SubagentRunEvent, bool) { + var evt SubagentRunEvent + if err := json.Unmarshal(line, &evt); err == nil && strings.TrimSpace(evt.RunID) != "" { + return evt, true + } + var record EventRecord + if err := json.Unmarshal(line, &record); err != nil || strings.TrimSpace(record.RunID) == "" { + return SubagentRunEvent{}, false + } + return SubagentRunEvent{ + RunID: record.RunID, + AgentID: record.AgentID, + Type: record.Type, + Status: record.Status, + FailureCode: record.FailureCode, + 
Message: record.Message, + RetryCount: record.RetryCount, + At: record.At, + }, true +} + +func cloneEventsByRun(src map[string][]SubagentRunEvent) map[string][]SubagentRunEvent { + if len(src) == 0 { + return map[string][]SubagentRunEvent{} + } + out := make(map[string][]SubagentRunEvent, len(src)) + for key, items := range src { + out[key] = append([]SubagentRunEvent(nil), items...) + } + return out +} + +func maxRunSeed(values ...int) int { + best := 0 + for _, value := range values { + if value > best { + best = value + } + } + return best +} + func cloneSubagentRun(run *SubagentRun) *SubagentRun { if run == nil { return nil @@ -336,4 +451,3 @@ func runToRunRecord(run *SubagentRun) RunRecord { UpdatedAt: run.Updated, } } - diff --git a/pkg/tools/subagent_store_test.go b/pkg/tools/subagent_store_test.go new file mode 100644 index 0000000..9988fd3 --- /dev/null +++ b/pkg/tools/subagent_store_test.go @@ -0,0 +1,70 @@ +package tools + +import ( + "os" + "path/filepath" + "testing" +) + +func TestSubagentRunStoreReloadsFromMetaAndJSONL(t *testing.T) { + t.Parallel() + + workspace := t.TempDir() + store := NewSubagentRunStore(workspace) + run := &SubagentRun{ + ID: "subagent-0007", + Task: "review deploy plan", + AgentID: "worker-a", + Status: RuntimeStatusCompleted, + Result: "done", + Created: 10, + Updated: 20, + } + if err := store.AppendRun(run); err != nil { + t.Fatalf("append run: %v", err) + } + if err := store.AppendEvent(SubagentRunEvent{ + RunID: run.ID, + AgentID: run.AgentID, + Type: "completed", + Status: RuntimeStatusCompleted, + Message: "done", + At: 20, + }); err != nil { + t.Fatalf("append event: %v", err) + } + + reloaded := NewSubagentRunStore(workspace) + if got, ok := reloaded.Get(run.ID); !ok || got.Result != "done" { + t.Fatalf("expected run restored, got %#v ok=%v", got, ok) + } + if seed := reloaded.NextIDSeed(); seed != 8 { + t.Fatalf("expected next seed 8, got %d", seed) + } + events, err := reloaded.Events(run.ID, 10) + if err != nil { + 
t.Fatalf("events: %v", err) + } + if len(events) != 1 || events[0].Type != "completed" { + t.Fatalf("unexpected events: %#v", events) + } + + runtimeDir := filepath.Join(workspace, "agents", "runtime") + if err := os.Remove(filepath.Join(runtimeDir, "subagent_runs.meta.json")); err != nil { + t.Fatalf("remove run meta: %v", err) + } + if err := os.Remove(filepath.Join(runtimeDir, "subagent_events.meta.json")); err != nil { + t.Fatalf("remove event meta: %v", err) + } + rebuilt := NewSubagentRunStore(workspace) + if seed := rebuilt.NextIDSeed(); seed != 8 { + t.Fatalf("expected rebuilt next seed 8, got %d", seed) + } + events, err = rebuilt.Events(run.ID, 10) + if err != nil { + t.Fatalf("events after rebuild: %v", err) + } + if len(events) != 1 || events[0].Message != "done" { + t.Fatalf("unexpected rebuilt events: %#v", events) + } +} diff --git a/pkg/tools/tool_allowlist_groups.go b/pkg/tools/tool_allowlist_groups.go index a301c2b..715d2f5 100644 --- a/pkg/tools/tool_allowlist_groups.go +++ b/pkg/tools/tool_allowlist_groups.go @@ -27,9 +27,9 @@ var defaultToolAllowlistGroups = []ToolAllowlistGroup{ }, { Name: "memory_read", - Description: "Read-only memory tools", + Description: "Read-only memory and session recall tools", Aliases: []string{"mem_read"}, - Tools: []string{"memory_search", "memory_get"}, + Tools: []string{"memory_search", "memory_get", "session_search"}, }, { Name: "memory_write",