From a9169c66ff809657b216bef4e711a345395abc20 Mon Sep 17 00:00:00 2001 From: lpf Date: Wed, 8 Apr 2026 15:25:28 +0800 Subject: [PATCH 1/5] Release v1.0.2 --- cmd/cmd_gateway.go | 1 + config.example.json | 12 + pkg/agent/loop.go | 1024 ++++++++++++++++---------- pkg/agent/loop_codex_options_test.go | 215 +++++- pkg/api/server.go | 533 +++++++++++++- pkg/api/server_test.go | 168 ++++- pkg/channels/weixin.go | 16 + pkg/config/config.go | 2 + pkg/config/normalized.go | 10 + pkg/config/normalized_test.go | 13 + pkg/config/validate.go | 3 + pkg/providers/codex_provider.go | 4 +- pkg/providers/codex_provider_test.go | 21 + pkg/providers/http_provider.go | 60 +- pkg/providers/oauth_test.go | 38 + 15 files changed, 1670 insertions(+), 450 deletions(-) diff --git a/cmd/cmd_gateway.go b/cmd/cmd_gateway.go index 5afda3a..d2adee4 100644 --- a/cmd/cmd_gateway.go +++ b/cmd/cmd_gateway.go @@ -275,6 +275,7 @@ func gatewayCmd() { registryServer.SetConfigAfterHook(func(forceRuntimeReload bool) error { return triggerReload("api", forceRuntimeReload) }) + registryServer.SetMessageBus(msgBus) if rawWeixin, ok := channelManager.GetChannel("weixin"); ok { if weixinChannel, ok := rawWeixin.(*channels.WeixinChannel); ok { weixinChannel.SetConfigPath(getConfigPath()) diff --git a/config.example.json b/config.example.json index 018b750..4eda1b1 100644 --- a/config.example.json +++ b/config.example.json @@ -178,6 +178,8 @@ "codex": { "api_base": "https://api.openai.com/v1", "models": ["gpt-5.4"], + "max_tokens": 8192, + "temperature": 0.7, "responses": { "web_search_enabled": false, "web_search_context_size": "", @@ -204,6 +206,8 @@ "gemini": { "api_base": "https://generativelanguage.googleapis.com/v1beta/openai", "models": ["gemini-2.5-pro"], + "max_tokens": 8192, + "temperature": 0.7, "responses": { "web_search_enabled": false, "web_search_context_size": "", @@ -231,6 +235,8 @@ "api_key": "sk-your-openai-api-key", "api_base": "https://api.openai.com/v1", "models": ["gpt-5.4", 
"gpt-5.4-mini"], + "max_tokens": 8192, + "temperature": 0.7, "responses": { "web_search_enabled": false, "web_search_context_size": "", @@ -245,6 +251,8 @@ "anthropic": { "api_base": "https://api.anthropic.com", "models": ["claude-sonnet-4-20250514"], + "max_tokens": 8192, + "temperature": 0.7, "responses": { "web_search_enabled": false, "web_search_context_size": "", @@ -270,6 +278,8 @@ "qwen": { "api_base": "https://dashscope.aliyuncs.com/compatible-mode/v1", "models": ["qwen-max"], + "max_tokens": 8192, + "temperature": 0.7, "responses": { "web_search_enabled": false, "web_search_context_size": "", @@ -294,6 +304,8 @@ "kimi": { "api_base": "https://api.moonshot.cn/v1", "models": ["kimi-k2-0711-preview"], + "max_tokens": 8192, + "temperature": 0.7, "responses": { "web_search_enabled": false, "web_search_context_size": "", diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go index f8113af..7c12080 100644 --- a/pkg/agent/loop.go +++ b/pkg/agent/loop.go @@ -40,6 +40,8 @@ type AgentLoop struct { provider providers.LLMProvider workspace string model string + maxTokens int + temperature float64 maxIterations int sessions *session.SessionManager contextBuilder *ContextBuilder @@ -56,6 +58,8 @@ type AgentLoop struct { providerNames []string providerPool map[string]providers.LLMProvider providerResponses map[string]config.ProviderResponsesConfig + providerMaxTokens map[string]int + providerTemperatures map[string]float64 telegramStreaming bool providerMu sync.RWMutex sessionProvider map[string]string @@ -222,6 +226,8 @@ func NewAgentLoop(cfg *config.Config, msgBus *bus.MessageBus, provider providers provider: provider, workspace: workspace, model: provider.GetDefaultModel(), + maxTokens: cfg.Agents.Defaults.MaxTokens, + temperature: cfg.Agents.Defaults.Temperature, maxIterations: cfg.Agents.Defaults.MaxToolIterations, sessions: sessionsManager, contextBuilder: NewContextBuilder(workspace, func() []string { return toolsRegistry.GetSummaries() }), @@ -237,6 +243,8 @@ func 
NewAgentLoop(cfg *config.Config, msgBus *bus.MessageBus, provider providers sessionProvider: map[string]string{}, sessionStreamed: map[string]bool{}, providerResponses: map[string]config.ProviderResponsesConfig{}, + providerMaxTokens: map[string]int{}, + providerTemperatures: map[string]float64{}, telegramStreaming: cfg.Channels.Telegram.Streaming, subagentManager: subagentManager, subagentRouter: subagentRouter, @@ -265,6 +273,8 @@ func NewAgentLoop(cfg *config.Config, msgBus *bus.MessageBus, provider providers loop.providerNames = append(loop.providerNames, primaryName) if pc, ok := config.ProviderConfigByName(cfg, primaryName); ok { loop.providerResponses[primaryName] = pc.Responses + loop.providerMaxTokens[primaryName] = pc.MaxTokens + loop.providerTemperatures[primaryName] = pc.Temperature } seenProviders := map[string]struct{}{primaryName: {}} providerConfigs := config.AllProviderConfigs(cfg) @@ -304,6 +314,8 @@ func NewAgentLoop(cfg *config.Config, msgBus *bus.MessageBus, provider providers modelName = strings.TrimSpace(pc.Models[0]) } loop.providerResponses[providerName] = pc.Responses + loop.providerMaxTokens[providerName] = pc.MaxTokens + loop.providerTemperatures[providerName] = pc.Temperature } seenProviders[providerName] = struct{}{} loop.providerNames = append(loop.providerNames, providerName) @@ -464,7 +476,7 @@ func (al *AgentLoop) buildSessionShards(ctx context.Context) []chan bus.InboundM return shards } -func (al *AgentLoop) tryFallbackProviders(ctx context.Context, msg bus.InboundMessage, messages []providers.Message, toolDefs []providers.ToolDefinition, options map[string]interface{}, primaryErr error) (*providers.LLMResponse, string, error) { +func (al *AgentLoop) tryFallbackProviders(ctx context.Context, msg bus.InboundMessage, messages []providers.Message, toolDefs []providers.ToolDefinition, primaryErr error) (*providers.LLMResponse, string, error) { if len(al.providerChain) <= 1 { return nil, "", primaryErr } @@ -495,8 +507,10 @@ func (al 
*AgentLoop) tryFallbackProviders(ctx context.Context, msg bus.InboundMe lastErr = err continue } - resp, err := p.Chat(ctx, messages, toolDefs, candidateModel, options) + fallbackOptions := al.buildResponsesOptionsForProvider(msg.SessionKey, candidate.name, int64(al.maxTokensForProvider(candidate.name)), al.temperatureForProvider(candidate.name)) + resp, err := p.Chat(ctx, messages, toolDefs, candidateModel, fallbackOptions) if err == nil { + al.setSessionProvider(msg.SessionKey, candidate.name) logger.WarnCF("agent", logger.C0150, map[string]interface{}{"provider": candidate.name, "model": candidateModel, "ref": candidate.ref}) return resp, candidate.name, nil } @@ -550,6 +564,76 @@ func (al *AgentLoop) ensureProviderCandidate(candidate providerCandidate) (provi return created, model, nil } +func (al *AgentLoop) providerCandidateByName(name string) (providerCandidate, bool) { + if al == nil { + return providerCandidate{}, false + } + target := strings.TrimSpace(name) + if target == "" { + return providerCandidate{}, false + } + for _, candidate := range al.providerChain { + if strings.EqualFold(strings.TrimSpace(candidate.name), target) { + return candidate, true + } + } + return providerCandidate{}, false +} + +func (al *AgentLoop) defaultProviderName() string { + if al == nil || len(al.providerNames) == 0 { + return "" + } + return strings.TrimSpace(al.providerNames[0]) +} + +func (al *AgentLoop) sessionProviderName(sessionKey string) string { + name := strings.TrimSpace(al.getSessionProvider(sessionKey)) + if name == "" { + name = al.defaultProviderName() + } + return name +} + +func (al *AgentLoop) isKnownProviderName(name string) bool { + target := strings.TrimSpace(name) + if target == "" { + return false + } + for _, item := range al.providerNames { + if strings.EqualFold(strings.TrimSpace(item), target) { + return true + } + } + return false +} + +func (al *AgentLoop) activeProviderForSession(sessionKey string) (providers.LLMProvider, string, string, 
error) { + if al == nil { + return nil, "", "", fmt.Errorf("agent loop is nil") + } + name := al.sessionProviderName(sessionKey) + if name == "" { + return al.provider, al.model, "", nil + } + if strings.EqualFold(name, al.defaultProviderName()) { + model := strings.TrimSpace(al.model) + if model == "" && al.provider != nil { + model = strings.TrimSpace(al.provider.GetDefaultModel()) + } + return al.provider, model, name, nil + } + candidate, ok := al.providerCandidateByName(name) + if !ok { + return al.provider, al.model, name, nil + } + p, model, err := al.ensureProviderCandidate(candidate) + if err != nil { + return nil, "", name, err + } + return p, model, name, nil +} + func automaticFallbackPriority(name string) int { switch normalizeFallbackProviderName(name) { case "claude": @@ -610,10 +694,22 @@ func (al *AgentLoop) getSessionProvider(sessionKey string) string { } func (al *AgentLoop) syncSessionDefaultProvider(sessionKey string) { - if al == nil || len(al.providerNames) == 0 { + if al == nil { return } - al.setSessionProvider(sessionKey, al.providerNames[0]) + current := strings.TrimSpace(al.getSessionProvider(sessionKey)) + if current == "" { + if name := al.defaultProviderName(); name != "" { + al.setSessionProvider(sessionKey, name) + } + return + } + if al.isKnownProviderName(current) { + return + } + if name := al.defaultProviderName(); name != "" { + al.setSessionProvider(sessionKey, name) + } } func (al *AgentLoop) markSessionStreamed(sessionKey string) { @@ -707,6 +803,8 @@ func sessionShardIndex(sessionKey string, shardCount int) int { return int(h.Sum32() % uint32(shardCount)) } +var thinkTagPattern = regexp.MustCompile(`(?s).*?`) + func (al *AgentLoop) getTrigger(msg bus.InboundMessage) string { if msg.Metadata != nil { if t := strings.TrimSpace(msg.Metadata["trigger"]); t != "" { @@ -748,6 +846,413 @@ func (al *AgentLoop) shouldSuppressOutbound(msg bus.InboundMessage, response str return len(r) <= maxChars } +type llmTurnLoopConfig struct { + 
ctx context.Context + triggerMsg bus.InboundMessage + sessionKey string + toolChannel string + toolChatID string + messages []providers.Message + media []string + mediaItems []bus.MediaItem + enableStreaming bool + errorLogCode logger.CodeID + logDirectResponse bool +} + +type llmTurnLoopResult struct { + messages []providers.Message + pendingPersist []providers.Message + finalContent string + iteration int + hasToolActivity bool +} + +func logLLMTurnRequest(iteration, maxIterations int, providerName, activeModel string, messages []providers.Message, providerToolDefs []providers.ToolDefinition, maxTokens int, temperature float64) { + systemPromptLen := 0 + if len(messages) > 0 { + systemPromptLen = len(messages[0].Content) + } + logger.DebugCF("agent", logger.C0152, map[string]interface{}{ + "iteration": iteration, + "max": maxIterations, + "provider": providerName, + "model": activeModel, + "messages_count": len(messages), + "tools_count": len(providerToolDefs), + "max_tokens": maxTokens, + "temperature": temperature, + "system_prompt_len": systemPromptLen, + }) + if iteration == 1 { + logger.DebugCF("agent", logger.C0153, map[string]interface{}{ + "iteration": iteration, + "messages_json": formatMessagesForLog(messages), + "tools_json": formatToolsForLog(providerToolDefs), + }) + } +} + +func logLLMDirectResponse(iteration int, finalContent string) { + logger.InfoCF("agent", logger.C0156, map[string]interface{}{ + "iteration": iteration, + "content_chars": len(finalContent), + }) +} + +func logLLMToolCalls(iteration int, toolCalls []providers.ToolCall) { + toolNames := make([]string, 0, len(toolCalls)) + for _, tc := range toolCalls { + toolNames = append(toolNames, tc.Name) + } + logger.InfoCF("agent", logger.C0157, map[string]interface{}{ + "tools": toolNames, + "count": len(toolNames), + "iteration": iteration, + }) +} + +func buildAssistantToolCallMessage(response *providers.LLMResponse) providers.Message { + assistantMsg := providers.Message{ + Role: 
"assistant", + Content: response.Content, + } + if response == nil { + return assistantMsg + } + for _, tc := range response.ToolCalls { + argumentsJSON, _ := json.Marshal(tc.Arguments) + assistantMsg.ToolCalls = append(assistantMsg.ToolCalls, providers.ToolCall{ + ID: tc.ID, + Type: "function", + Function: &providers.FunctionCall{ + Name: tc.Name, + Arguments: string(argumentsJSON), + }, + }) + } + return assistantMsg +} + +func (al *AgentLoop) executeResponseToolCalls(cfg llmTurnLoopConfig, iteration int, response *providers.LLMResponse) []providers.Message { + if response == nil || len(response.ToolCalls) == 0 { + return nil + } + results := make([]providers.Message, 0, len(response.ToolCalls)) + for _, tc := range response.ToolCalls { + argsJSON, _ := json.Marshal(tc.Arguments) + logger.InfoCF("agent", logger.C0172, map[string]interface{}{ + "tool": tc.Name, + "args": truncate(string(argsJSON), 200), + "iteration": iteration, + }) + execArgs := withToolContextArgs(tc.Name, tc.Arguments, cfg.toolChannel, cfg.toolChatID) + toolResult, toolErr := al.executeToolCall(cfg.ctx, tc.Name, execArgs, cfg.toolChannel, cfg.toolChatID) + if toolErr != nil { + toolResult = fmt.Sprintf("Error: %v", toolErr) + } + results = append(results, providers.Message{ + Role: "tool", + Content: toolResult, + ToolCallID: tc.ID, + }) + } + return results +} + +func (al *AgentLoop) requestLLMResponse(cfg llmTurnLoopConfig, activeProvider providers.LLMProvider, activeModel string, messages []providers.Message, providerToolDefs []providers.ToolDefinition, options map[string]interface{}) (*providers.LLMResponse, error) { + if cfg.enableStreaming { + if sp, ok := activeProvider.(providers.StreamingLLMProvider); ok { + streamText := "" + lastPush := time.Now().Add(-time.Second) + return sp.ChatStream(cfg.ctx, messages, providerToolDefs, activeModel, options, func(delta string) { + if strings.TrimSpace(delta) == "" { + return + } + streamText += delta + if time.Since(lastPush) < 
450*time.Millisecond { + return + } + if !shouldFlushTelegramStreamSnapshot(streamText) { + return + } + lastPush = time.Now() + replyID := "" + if cfg.triggerMsg.Metadata != nil { + replyID = cfg.triggerMsg.Metadata["message_id"] + } + al.bus.PublishOutbound(bus.OutboundMessage{ + Channel: cfg.toolChannel, + ChatID: cfg.toolChatID, + Content: streamText, + Action: "stream", + ReplyToID: replyID, + }) + al.markSessionStreamed(cfg.sessionKey) + }) + } + } + return activeProvider.Chat(cfg.ctx, messages, providerToolDefs, activeModel, options) +} + +func (al *AgentLoop) runLLMTurnLoop(cfg llmTurnLoopConfig) (llmTurnLoopResult, error) { + result := llmTurnLoopResult{ + messages: append([]providers.Message(nil), cfg.messages...), + pendingPersist: make([]providers.Message, 0, 16), + } + maxAllowed := al.maxIterations + if maxAllowed < 1 { + maxAllowed = 1 + } + toolDefs := al.filteredToolDefinitionsForContext(cfg.ctx) + providerToolDefs := al.buildProviderToolDefs(toolDefs) + result.messages = injectResponsesMediaParts(result.messages, cfg.media, cfg.mediaItems) + + for result.iteration < maxAllowed { + result.iteration++ + activeProvider, activeModel, providerName, err := al.activeProviderForSession(cfg.sessionKey) + if err != nil { + logger.ErrorCF("agent", cfg.errorLogCode, map[string]interface{}{ + "iteration": result.iteration, + "error": err.Error(), + }) + return result, fmt.Errorf("resolve active provider: %w", err) + } + if activeProvider == nil { + return result, fmt.Errorf("active provider unavailable for session %s", strings.TrimSpace(cfg.sessionKey)) + } + + maxTokens := al.maxTokensForProvider(providerName) + temperature := al.temperatureForProvider(providerName) + logLLMTurnRequest(result.iteration, al.maxIterations, providerName, activeModel, result.messages, providerToolDefs, maxTokens, temperature) + + options := al.buildResponsesOptions(cfg.sessionKey, int64(maxTokens), temperature) + response, err := al.requestLLMResponse(cfg, activeProvider, 
activeModel, result.messages, providerToolDefs, options) + + if err != nil { + if fb, _, ferr := al.tryFallbackProviders(cfg.ctx, cfg.triggerMsg, result.messages, providerToolDefs, err); ferr == nil && fb != nil { + response = fb + err = nil + } else { + err = ferr + } + } + if err != nil { + logger.ErrorCF("agent", cfg.errorLogCode, map[string]interface{}{ + "iteration": result.iteration, + "error": err.Error(), + }) + return result, fmt.Errorf("LLM call failed: %w", err) + } + + if len(response.ToolCalls) == 0 { + result.finalContent = response.Content + if cfg.logDirectResponse { + logLLMDirectResponse(result.iteration, result.finalContent) + } + return result, nil + } + + logLLMToolCalls(result.iteration, response.ToolCalls) + + assistantMsg := buildAssistantToolCallMessage(response) + result.messages = append(result.messages, assistantMsg) + result.pendingPersist = append(result.pendingPersist, assistantMsg) + result.hasToolActivity = true + if maxAllowed < result.iteration+al.maxIterations { + maxAllowed = result.iteration + al.maxIterations + } + + for _, toolResultMsg := range al.executeResponseToolCalls(cfg, result.iteration, response) { + result.messages = append(result.messages, toolResultMsg) + result.pendingPersist = append(result.pendingPersist, toolResultMsg) + } + } + + return result, nil +} + +func (al *AgentLoop) logInboundMessageStart(msg bus.InboundMessage) { + logger.InfoCF("agent", logger.C0171, map[string]interface{}{ + "channel": msg.Channel, + "chat_id": msg.ChatID, + "sender_id": msg.SenderID, + "session_key": msg.SessionKey, + "preview": truncate(msg.Content, 80), + }) +} + +func (al *AgentLoop) prepareUserMessageContext(msg bus.InboundMessage, memoryNamespace string) ([]providers.Message, string) { + history := al.sessions.GetHistory(msg.SessionKey) + summary := al.sessions.GetSummary(msg.SessionKey) + al.sessions.AddMessage(msg.SessionKey, "user", msg.Content) + if explicitPref := ExtractLanguagePreference(msg.Content); explicitPref != 
"" { + al.sessions.SetPreferredLanguage(msg.SessionKey, explicitPref) + } + preferredLang, lastLang := al.sessions.GetLanguagePreferences(msg.SessionKey) + responseLang := DetectResponseLanguage(msg.Content, preferredLang, lastLang) + messages := al.contextBuilder.BuildMessagesWithMemoryNamespace( + history, + summary, + msg.Content, + nil, + msg.Channel, + msg.ChatID, + responseLang, + memoryNamespace, + ) + return messages, responseLang +} + +func (al *AgentLoop) finalizeUserMessage(sessionKey, responseLang string, pendingPersist []providers.Message, finalContent string) { + for _, persisted := range pendingPersist { + al.sessions.AddMessageFull(sessionKey, persisted) + } + al.sessions.AddMessageFull(sessionKey, providers.Message{ + Role: "assistant", + Content: finalContent, + }) + al.sessions.SetLastLanguage(sessionKey, responseLang) + al.compactSessionIfNeeded(sessionKey) + _ = al.sessions.Save(al.sessions.GetOrCreate(sessionKey)) +} + +func (al *AgentLoop) prepareSystemMessageContext(sessionKey string, msg bus.InboundMessage, originChannel, originChatID string) ([]providers.Message, string) { + history := al.sessions.GetHistory(sessionKey) + summary := al.sessions.GetSummary(sessionKey) + preferredLang, lastLang := al.sessions.GetLanguagePreferences(sessionKey) + responseLang := DetectResponseLanguage(msg.Content, preferredLang, lastLang) + messages := al.contextBuilder.BuildMessages( + history, + summary, + msg.Content, + nil, + originChannel, + originChatID, + responseLang, + ) + return messages, responseLang +} + +func (al *AgentLoop) finalizeSystemMessage(sessionKey, responseLang string, msg bus.InboundMessage, pendingPersist []providers.Message, finalContent string) { + al.sessions.AddMessage(sessionKey, "user", fmt.Sprintf("[System: %s] %s", msg.SenderID, msg.Content)) + for _, persisted := range pendingPersist { + al.sessions.AddMessageFull(sessionKey, persisted) + } + al.sessions.AddMessageFull(sessionKey, providers.Message{ + Role: "assistant", + 
Content: finalContent, + }) + al.sessions.SetLastLanguage(sessionKey, responseLang) + al.compactSessionIfNeeded(sessionKey) + _ = al.sessions.Save(al.sessions.GetOrCreate(sessionKey)) +} + +func (al *AgentLoop) startSpecTaskForMessage(msg bus.InboundMessage) specCodingTaskRef { + specTaskRef := specCodingTaskRef{} + if err := al.maybeEnsureSpecCodingDocs(msg.Content); err != nil { + logger.WarnCF("agent", logger.C0172, map[string]interface{}{ + "session_key": msg.SessionKey, + "error": err.Error(), + }) + } + taskRef, err := al.maybeStartSpecCodingTask(msg.Content) + if err != nil { + logger.WarnCF("agent", logger.C0172, map[string]interface{}{ + "session_key": msg.SessionKey, + "error": err.Error(), + }) + return specTaskRef + } + return normalizeSpecCodingTaskRef(taskRef) +} + +func (al *AgentLoop) reopenSpecTaskOnError(specTaskRef specCodingTaskRef, msg bus.InboundMessage, err error) { + if specTaskRef.Summary == "" || err == nil { + return + } + if rerr := al.maybeReopenSpecCodingTask(specTaskRef, msg.Content, err.Error()); rerr != nil { + logger.WarnCF("agent", logger.C0172, map[string]interface{}{ + "session_key": msg.SessionKey, + "error": rerr.Error(), + }) + } +} + +func (al *AgentLoop) completeSpecTaskOnSuccess(specTaskRef specCodingTaskRef, msg bus.InboundMessage, output string) { + if specTaskRef.Summary == "" { + return + } + if err := al.maybeCompleteSpecCodingTask(specTaskRef, output); err != nil { + logger.WarnCF("agent", logger.C0172, map[string]interface{}{ + "session_key": msg.SessionKey, + "error": err.Error(), + }) + } +} + +func (al *AgentLoop) recoverFinalContentAfterToolCalls(ctx context.Context, sessionKey string, messages []providers.Message, hasToolActivity bool) string { + if !hasToolActivity { + return "" + } + activeProvider, activeModel, providerName, err := al.activeProviderForSession(sessionKey) + if err != nil { + logger.WarnCF("agent", logger.C0172, map[string]interface{}{ + "session_key": sessionKey, + "error": err.Error(), + }) 
+ return "" + } + if activeProvider == nil { + return "" + } + options := al.buildResponsesOptionsForProvider(sessionKey, providerName, int64(al.maxTokensForProvider(providerName)), 0.2) + forced, ferr := activeProvider.Chat(ctx, messages, nil, activeModel, options) + if ferr != nil || forced == nil { + if ferr != nil { + logger.WarnCF("agent", logger.C0172, map[string]interface{}{ + "session_key": sessionKey, + "error": ferr.Error(), + }) + } + return "" + } + return forced.Content +} + +func sanitizeUserVisibleContent(finalContent string, iteration int) string { + userContent := thinkTagPattern.ReplaceAllString(finalContent, "") + if userContent == "" && finalContent != "" && iteration == 1 { + return "Thinking process completed." + } + return userContent +} + +func (al *AgentLoop) finalizeUserTurnResponse(ctx context.Context, msg bus.InboundMessage, responseLang string, loopResult llmTurnLoopResult) (string, string) { + finalContent := loopResult.finalContent + if finalContent == "" { + if recovered := al.recoverFinalContentAfterToolCalls(ctx, msg.SessionKey, loopResult.messages, loopResult.hasToolActivity); recovered != "" { + finalContent = recovered + } + } + userContent := sanitizeUserVisibleContent(finalContent, loopResult.iteration) + al.finalizeUserMessage(msg.SessionKey, responseLang, loopResult.pendingPersist, userContent) + return finalContent, userContent +} + +func (al *AgentLoop) maybeHandleAutoRoute(ctx context.Context, msg bus.InboundMessage, specTaskRef specCodingTaskRef) (string, error, bool) { + routed, ok, routeErr := al.maybeAutoRoute(ctx, msg) + if !ok { + return "", nil, false + } + if routeErr != nil { + al.reopenSpecTaskOnError(specTaskRef, msg, routeErr) + return routed, routeErr, true + } + al.completeSpecTaskOnSuccess(specTaskRef, msg, routed) + return routed, nil, true +} + func loadHeartbeatAckToken(workspace string) string { workspace = strings.TrimSpace(workspace) if workspace == "" { @@ -879,282 +1384,37 @@ func (al *AgentLoop) 
processMessage(ctx context.Context, msg bus.InboundMessage) } defer release() al.syncSessionDefaultProvider(msg.SessionKey) - // Add message preview to log - preview := truncate(msg.Content, 80) - logger.InfoCF("agent", logger.C0171, - map[string]interface{}{ - "channel": msg.Channel, - "chat_id": msg.ChatID, - "sender_id": msg.SenderID, - "session_key": msg.SessionKey, - "preview": preview, - }) + al.logInboundMessageStart(msg) // Route system messages to processSystemMessage if msg.Channel == "system" { return al.processSystemMessage(ctx, msg) } - specTaskRef := specCodingTaskRef{} - if err := al.maybeEnsureSpecCodingDocs(msg.Content); err != nil { - logger.WarnCF("agent", logger.C0172, map[string]interface{}{ - "session_key": msg.SessionKey, - "error": err.Error(), - }) - } - if taskRef, err := al.maybeStartSpecCodingTask(msg.Content); err != nil { - logger.WarnCF("agent", logger.C0172, map[string]interface{}{ - "session_key": msg.SessionKey, - "error": err.Error(), - }) - } else { - specTaskRef = normalizeSpecCodingTaskRef(taskRef) - } - if routed, ok, routeErr := al.maybeAutoRoute(ctx, msg); ok { - if routeErr != nil && specTaskRef.Summary != "" { - if err := al.maybeReopenSpecCodingTask(specTaskRef, msg.Content, routeErr.Error()); err != nil { - logger.WarnCF("agent", logger.C0172, map[string]interface{}{ - "session_key": msg.SessionKey, - "error": err.Error(), - }) - } - } - if routeErr == nil && specTaskRef.Summary != "" { - if err := al.maybeCompleteSpecCodingTask(specTaskRef, routed); err != nil { - logger.WarnCF("agent", logger.C0172, map[string]interface{}{ - "session_key": msg.SessionKey, - "error": err.Error(), - }) - } - } + specTaskRef := al.startSpecTaskForMessage(msg) + if routed, routeErr, handled := al.maybeHandleAutoRoute(ctx, msg, specTaskRef); handled { return routed, routeErr } - history := al.sessions.GetHistory(msg.SessionKey) - summary := al.sessions.GetSummary(msg.SessionKey) - if explicitPref := ExtractLanguagePreference(msg.Content); 
explicitPref != "" { - al.sessions.SetPreferredLanguage(msg.SessionKey, explicitPref) - } - preferredLang, lastLang := al.sessions.GetLanguagePreferences(msg.SessionKey) - responseLang := DetectResponseLanguage(msg.Content, preferredLang, lastLang) + messages, responseLang := al.prepareUserMessageContext(msg, memoryNamespace) - messages := al.contextBuilder.BuildMessagesWithMemoryNamespace( - history, - summary, - msg.Content, - nil, - msg.Channel, - msg.ChatID, - responseLang, - memoryNamespace, - ) - - iteration := 0 - var finalContent string - hasToolActivity := false - lastToolOutputs := make([]string, 0, 4) - maxAllowed := al.maxIterations - if maxAllowed < 1 { - maxAllowed = 1 - } - for iteration < maxAllowed { - iteration++ - - logger.DebugCF("agent", logger.C0151, - map[string]interface{}{ - "iteration": iteration, - "max": al.maxIterations, - }) - - toolDefs := al.filteredToolDefinitionsForContext(ctx) - providerToolDefs := al.buildProviderToolDefs(toolDefs) - - // Log LLM request details - logger.DebugCF("agent", logger.C0152, - map[string]interface{}{ - "iteration": iteration, - "model": al.model, - "messages_count": len(messages), - "tools_count": len(providerToolDefs), - "max_tokens": 8192, - "temperature": 0.7, - "system_prompt_len": len(messages[0].Content), - }) - - // Log full messages (detailed) - logger.DebugCF("agent", logger.C0153, - map[string]interface{}{ - "iteration": iteration, - "messages_json": formatMessagesForLog(messages), - "tools_json": formatToolsForLog(providerToolDefs), - }) - - messages = injectResponsesMediaParts(messages, msg.Media, msg.MediaItems) - options := al.buildResponsesOptions(msg.SessionKey, 8192, 0.7) - var response *providers.LLMResponse - var err error - if msg.Channel == "telegram" && al.telegramStreaming { - if sp, ok := al.provider.(providers.StreamingLLMProvider); ok { - streamText := "" - lastPush := time.Now().Add(-time.Second) - response, err = sp.ChatStream(ctx, messages, providerToolDefs, al.model, 
options, func(delta string) { - if strings.TrimSpace(delta) == "" { - return - } - streamText += delta - if time.Since(lastPush) < 450*time.Millisecond { - return - } - if !shouldFlushTelegramStreamSnapshot(streamText) { - return - } - lastPush = time.Now() - replyID := "" - if msg.Metadata != nil { - replyID = msg.Metadata["message_id"] - } - // Stream with formatted rendering once snapshot is syntactically safe. - al.bus.PublishOutbound(bus.OutboundMessage{Channel: msg.Channel, ChatID: msg.ChatID, Content: streamText, Action: "stream", ReplyToID: replyID}) - al.markSessionStreamed(msg.SessionKey) - }) - } else { - response, err = al.provider.Chat(ctx, messages, providerToolDefs, al.model, options) - } - } else { - response, err = al.provider.Chat(ctx, messages, providerToolDefs, al.model, options) - } - - if err != nil { - if fb, _, ferr := al.tryFallbackProviders(ctx, msg, messages, providerToolDefs, options, err); ferr == nil && fb != nil { - response = fb - err = nil - } else { - err = ferr - } - } - if err != nil { - logger.ErrorCF("agent", logger.C0155, - map[string]interface{}{ - "iteration": iteration, - "error": err.Error(), - }) - if specTaskRef.Summary != "" { - if rerr := al.maybeReopenSpecCodingTask(specTaskRef, msg.Content, err.Error()); rerr != nil { - logger.WarnCF("agent", logger.C0172, map[string]interface{}{ - "session_key": msg.SessionKey, - "error": rerr.Error(), - }) - } - } - return "", fmt.Errorf("LLM call failed: %w", err) - } - - if len(response.ToolCalls) == 0 { - finalContent = response.Content - logger.InfoCF("agent", logger.C0156, - map[string]interface{}{ - "iteration": iteration, - "content_chars": len(finalContent), - }) - break - } - - toolNames := make([]string, 0, len(response.ToolCalls)) - for _, tc := range response.ToolCalls { - toolNames = append(toolNames, tc.Name) - } - logger.InfoCF("agent", logger.C0157, - map[string]interface{}{ - "tools": toolNames, - "count": len(toolNames), - "iteration": iteration, - }) - - 
assistantMsg := providers.Message{ - Role: "assistant", - Content: response.Content, - } - - for _, tc := range response.ToolCalls { - argumentsJSON, _ := json.Marshal(tc.Arguments) - assistantMsg.ToolCalls = append(assistantMsg.ToolCalls, providers.ToolCall{ - ID: tc.ID, - Type: "function", - Function: &providers.FunctionCall{ - Name: tc.Name, - Arguments: string(argumentsJSON), - }, - }) - } - messages = append(messages, assistantMsg) - // Persist assistant message with tool calls. - al.sessions.AddMessageFull(msg.SessionKey, assistantMsg) - - hasToolActivity = true - // Extend rolling window as long as tools keep chaining. - if maxAllowed < iteration+al.maxIterations { - maxAllowed = iteration + al.maxIterations - } - for _, tc := range response.ToolCalls { - // Log tool call with arguments preview - argsJSON, _ := json.Marshal(tc.Arguments) - argsPreview := truncate(string(argsJSON), 200) - logger.InfoCF("agent", logger.C0172, - map[string]interface{}{ - "tool": tc.Name, - "args": argsPreview, - "iteration": iteration, - }) - - execArgs := withToolContextArgs(tc.Name, tc.Arguments, msg.Channel, msg.ChatID) - result, err := al.executeToolCall(ctx, tc.Name, execArgs, msg.Channel, msg.ChatID) - if err != nil { - result = fmt.Sprintf("Error: %v", err) - } - if len(lastToolOutputs) < 4 { - lastToolOutputs = append(lastToolOutputs, fmt.Sprintf("%s: %s", tc.Name, truncate(strings.ReplaceAll(result, "\n", " "), 180))) - } - toolResultMsg := providers.Message{ - Role: "tool", - Content: result, - ToolCallID: tc.ID, - } - messages = append(messages, toolResultMsg) - // Persist tool result message. - al.sessions.AddMessageFull(msg.SessionKey, toolResultMsg) - } - } - - if finalContent == "" && hasToolActivity { - forced, ferr := al.provider.Chat(ctx, messages, nil, al.model, map[string]interface{}{"max_tokens": 8192, "temperature": 0.2}) - if ferr == nil && forced != nil && forced.Content != "" { - finalContent = forced.Content - } - } - - // Filter out ... 
content from user-facing response - // Keep full content in debug logs if needed, but remove from final output - re := regexp.MustCompile(`(?s).*?`) - userContent := re.ReplaceAllString(finalContent, "") - if userContent == "" && finalContent != "" { - // If only thoughts were present, maybe provide a generic "Done" or keep something? - // For now, let's assume thoughts are auxiliary and empty response is okay if tools did work. - // If no tools ran and only thoughts, user might be confused. - if iteration == 1 { - userContent = "Thinking process completed." - } - } - - al.sessions.AddMessage(msg.SessionKey, "user", msg.Content) - - // Persist full assistant response (including reasoning/tool flow outcomes when present). - al.sessions.AddMessageFull(msg.SessionKey, providers.Message{ - Role: "assistant", - Content: userContent, + loopResult, err := al.runLLMTurnLoop(llmTurnLoopConfig{ + ctx: ctx, + triggerMsg: msg, + sessionKey: msg.SessionKey, + toolChannel: msg.Channel, + toolChatID: msg.ChatID, + messages: messages, + media: msg.Media, + mediaItems: msg.MediaItems, + enableStreaming: msg.Channel == "telegram" && al.telegramStreaming, + errorLogCode: logger.C0155, + logDirectResponse: true, }) - al.sessions.SetLastLanguage(msg.SessionKey, responseLang) - al.compactSessionIfNeeded(msg.SessionKey) - - al.sessions.Save(al.sessions.GetOrCreate(msg.SessionKey)) + if err != nil { + al.reopenSpecTaskOnError(specTaskRef, msg, err) + return "", err + } + finalContent, userContent := al.finalizeUserTurnResponse(ctx, msg, responseLang, loopResult) // Log response preview (original content) responsePreview := truncate(finalContent, 120) @@ -1163,19 +1423,12 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage) "channel": msg.Channel, "sender_id": msg.SenderID, "preview": responsePreview, - "iterations": iteration, + "iterations": loopResult.iteration, "final_length": len(finalContent), "user_length": len(userContent), }) - if specTaskRef.Summary 
!= "" { - if err := al.maybeCompleteSpecCodingTask(specTaskRef, userContent); err != nil { - logger.WarnCF("agent", logger.C0172, map[string]interface{}{ - "session_key": msg.SessionKey, - "error": err.Error(), - }) - } - } + al.completeSpecTaskOnSuccess(specTaskRef, msg, userContent) al.appendDailySummaryLog(msg, userContent) return userContent, nil } @@ -1327,130 +1580,30 @@ func (al *AgentLoop) processSystemMessage(ctx context.Context, msg bus.InboundMe // Use the origin session for context sessionKey := fmt.Sprintf("%s:%s", originChannel, originChatID) - // Build messages with the announce content - history := al.sessions.GetHistory(sessionKey) - summary := al.sessions.GetSummary(sessionKey) - preferredLang, lastLang := al.sessions.GetLanguagePreferences(sessionKey) - responseLang := DetectResponseLanguage(msg.Content, preferredLang, lastLang) - messages := al.contextBuilder.BuildMessages( - history, - summary, - msg.Content, - nil, - originChannel, - originChatID, - responseLang, - ) + messages, responseLang := al.prepareSystemMessageContext(sessionKey, msg, originChannel, originChatID) - iteration := 0 - var finalContent string - - for iteration < al.maxIterations { - iteration++ - - toolDefs := al.filteredToolDefinitionsForContext(ctx) - providerToolDefs := al.buildProviderToolDefs(toolDefs) - - // Log LLM request details - logger.DebugCF("agent", logger.C0152, - map[string]interface{}{ - "iteration": iteration, - "model": al.model, - "messages_count": len(messages), - "tools_count": len(providerToolDefs), - "max_tokens": 8192, - "temperature": 0.7, - "system_prompt_len": len(messages[0].Content), - }) - - // Log full messages (detailed) - logger.DebugCF("agent", logger.C0153, - map[string]interface{}{ - "iteration": iteration, - "messages_json": formatMessagesForLog(messages), - "tools_json": formatToolsForLog(providerToolDefs), - }) - - options := al.buildResponsesOptions(sessionKey, 8192, 0.7) - response, err := al.provider.Chat(ctx, messages, 
providerToolDefs, al.model, options) - - if err != nil { - if fb, _, ferr := al.tryFallbackProviders(ctx, msg, messages, providerToolDefs, options, err); ferr == nil && fb != nil { - response = fb - err = nil - } else { - err = ferr - } - } - if err != nil { - logger.ErrorCF("agent", logger.C0162, - map[string]interface{}{ - "iteration": iteration, - "error": err.Error(), - }) - return "", fmt.Errorf("LLM call failed: %w", err) - } - - if len(response.ToolCalls) == 0 { - finalContent = response.Content - break - } - - assistantMsg := providers.Message{ - Role: "assistant", - Content: response.Content, - } - - for _, tc := range response.ToolCalls { - argumentsJSON, _ := json.Marshal(tc.Arguments) - assistantMsg.ToolCalls = append(assistantMsg.ToolCalls, providers.ToolCall{ - ID: tc.ID, - Type: "function", - Function: &providers.FunctionCall{ - Name: tc.Name, - Arguments: string(argumentsJSON), - }, - }) - } - messages = append(messages, assistantMsg) - // Persist assistant message with tool calls. - al.sessions.AddMessageFull(sessionKey, assistantMsg) - - for _, tc := range response.ToolCalls { - execArgs := withToolContextArgs(tc.Name, tc.Arguments, originChannel, originChatID) - result, err := al.executeToolCall(ctx, tc.Name, execArgs, originChannel, originChatID) - if err != nil { - result = fmt.Sprintf("Error: %v", err) - } - - toolResultMsg := providers.Message{ - Role: "tool", - Content: result, - ToolCallID: tc.ID, - } - messages = append(messages, toolResultMsg) - // Persist tool result message. 
- al.sessions.AddMessageFull(sessionKey, toolResultMsg) - } + loopResult, err := al.runLLMTurnLoop(llmTurnLoopConfig{ + ctx: ctx, + triggerMsg: msg, + sessionKey: sessionKey, + toolChannel: originChannel, + toolChatID: originChatID, + messages: messages, + errorLogCode: logger.C0162, + logDirectResponse: false, + }) + if err != nil { + return "", err } + iteration := loopResult.iteration + finalContent := loopResult.finalContent + pendingPersist := loopResult.pendingPersist if finalContent == "" { finalContent = "Background task completed." } - // Save to session with system message marker - al.sessions.AddMessage(sessionKey, "user", fmt.Sprintf("[System: %s] %s", msg.SenderID, msg.Content)) - - // If finalContent has no tool calls (last LLM turn is direct text), - // earlier steps were already persisted in-loop; this stores the final reply. - al.sessions.AddMessageFull(sessionKey, providers.Message{ - Role: "assistant", - Content: finalContent, - }) - al.sessions.SetLastLanguage(sessionKey, responseLang) - al.compactSessionIfNeeded(sessionKey) - - al.sessions.Save(al.sessions.GetOrCreate(sessionKey)) + al.finalizeSystemMessage(sessionKey, responseLang, msg, pendingPersist, finalContent) logger.InfoCF("agent", logger.C0163, map[string]interface{}{ @@ -1564,16 +1717,30 @@ func filterToolDefinitionsByContext(ctx context.Context, toolDefs []map[string]i } func (al *AgentLoop) buildResponsesOptions(sessionKey string, maxTokens int64, temperature float64) map[string]interface{} { + providerName := strings.TrimSpace(al.getSessionProvider(sessionKey)) + if providerName == "" && len(al.providerNames) > 0 { + providerName = al.providerNames[0] + } + return al.buildResponsesOptionsForProvider(sessionKey, providerName, maxTokens, temperature) +} + +func (al *AgentLoop) buildResponsesOptionsForProvider(sessionKey, providerName string, maxTokens int64, temperature float64) map[string]interface{} { + if maxTokens <= 0 { + maxTokens = int64(al.maxTokensForProvider(providerName)) 
+ } + if math.IsNaN(temperature) { + temperature = al.temperatureForProvider(providerName) + } options := map[string]interface{}{ "max_tokens": maxTokens, "temperature": temperature, } - if strings.EqualFold(strings.TrimSpace(al.getSessionProvider(sessionKey)), "codex") { + if strings.EqualFold(strings.TrimSpace(providerName), "codex") { if key := strings.TrimSpace(sessionKey); key != "" { options["codex_execution_session"] = key } } - responsesCfg := al.responsesConfigForSession(sessionKey) + responsesCfg := al.responsesConfigForProvider(providerName) responseTools := make([]map[string]interface{}, 0, 2) if responsesCfg.WebSearchEnabled { webTool := map[string]interface{}{"type": "web_search"} @@ -1604,6 +1771,60 @@ func (al *AgentLoop) buildResponsesOptions(sessionKey string, maxTokens int64, t return options } +func (al *AgentLoop) maxTokensForProvider(name string) int { + if al == nil { + return 8192 + } + providerName := strings.TrimSpace(name) + if providerName != "" { + if limit, ok := al.providerMaxTokens[providerName]; ok && limit > 0 { + return limit + } + } + if al.maxTokens > 0 { + return al.maxTokens + } + return 8192 +} + +func (al *AgentLoop) maxTokensForSession(sessionKey string) int { + if al == nil { + return 8192 + } + name := strings.TrimSpace(al.getSessionProvider(sessionKey)) + if name == "" && len(al.providerNames) > 0 { + name = al.providerNames[0] + } + return al.maxTokensForProvider(name) +} + +func (al *AgentLoop) temperatureForProvider(name string) float64 { + if al == nil { + return 0.7 + } + providerName := strings.TrimSpace(name) + if providerName != "" { + if value, ok := al.providerTemperatures[providerName]; ok && value != 0 { + return value + } + } + if al.temperature != 0 { + return al.temperature + } + return 0.7 +} + +func (al *AgentLoop) temperatureForSession(sessionKey string) float64 { + if al == nil { + return 0.7 + } + name := strings.TrimSpace(al.getSessionProvider(sessionKey)) + if name == "" && len(al.providerNames) > 0 
{ + name = al.providerNames[0] + } + return al.temperatureForProvider(name) +} + func (al *AgentLoop) responsesConfigForSession(sessionKey string) config.ProviderResponsesConfig { if al == nil { return config.ProviderResponsesConfig{} @@ -1612,6 +1833,13 @@ func (al *AgentLoop) responsesConfigForSession(sessionKey string) config.Provide if name == "" && len(al.providerNames) > 0 { name = al.providerNames[0] } + return al.responsesConfigForProvider(name) +} + +func (al *AgentLoop) responsesConfigForProvider(name string) config.ProviderResponsesConfig { + if al == nil { + return config.ProviderResponsesConfig{} + } if name == "" { return config.ProviderResponsesConfig{} } diff --git a/pkg/agent/loop_codex_options_test.go b/pkg/agent/loop_codex_options_test.go index 5ca8b3d..517f302 100644 --- a/pkg/agent/loop_codex_options_test.go +++ b/pkg/agent/loop_codex_options_test.go @@ -1,6 +1,63 @@ package agent -import "testing" +import ( + "context" + "errors" + "testing" + + "github.com/YspCoder/clawgo/pkg/bus" + "github.com/YspCoder/clawgo/pkg/config" + "github.com/YspCoder/clawgo/pkg/providers" +) + +type fallbackTestProvider struct { + response *providers.LLMResponse + err error + options map[string]interface{} +} + +func (p *fallbackTestProvider) Chat(ctx context.Context, messages []providers.Message, tools []providers.ToolDefinition, model string, options map[string]interface{}) (*providers.LLMResponse, error) { + p.options = map[string]interface{}{} + for k, v := range options { + p.options[k] = v + } + if p.err != nil { + return nil, p.err + } + return p.response, nil +} + +func (p *fallbackTestProvider) GetDefaultModel() string { return "fallback-model" } + +type sequenceProvider struct { + responses []*providers.LLMResponse + errs []error +} + +func (p *sequenceProvider) Chat(ctx context.Context, messages []providers.Message, tools []providers.ToolDefinition, model string, options map[string]interface{}) (*providers.LLMResponse, error) { + if len(p.responses) == 0 
&& len(p.errs) == 0 { + return &providers.LLMResponse{Content: "ok", FinishReason: "stop"}, nil + } + resp := (*providers.LLMResponse)(nil) + if len(p.responses) > 0 { + resp = p.responses[0] + p.responses = p.responses[1:] + } + var err error + if len(p.errs) > 0 { + err = p.errs[0] + p.errs = p.errs[1:] + } + if err != nil { + return nil, err + } + if resp == nil { + return &providers.LLMResponse{Content: "ok", FinishReason: "stop"}, nil + } + return resp, nil +} + +func (p *sequenceProvider) GetDefaultModel() string { return "sequence-model" } func TestBuildResponsesOptionsAddsCodexExecutionSession(t *testing.T) { loop := &AgentLoop{ @@ -42,3 +99,159 @@ func TestSyncSessionDefaultProviderOverridesStaleSessionProvider(t *testing.T) { t.Fatalf("expected stale session provider to be replaced with current default, got %q", got) } } + +func TestSyncSessionDefaultProviderKeepsKnownSessionProvider(t *testing.T) { + loop := &AgentLoop{ + providerNames: []string{"openai", "claude"}, + sessionProvider: map[string]string{ + "chat-1": "claude", + }, + } + + loop.syncSessionDefaultProvider("chat-1") + + if got := loop.getSessionProvider("chat-1"); got != "claude" { + t.Fatalf("expected valid session provider to be preserved, got %q", got) + } +} + +func TestMaxTokensForSessionUsesProviderOverride(t *testing.T) { + loop := &AgentLoop{ + maxTokens: 4096, + providerNames: []string{"openai"}, + sessionProvider: map[string]string{ + "chat-1": "claude", + }, + providerMaxTokens: map[string]int{ + "claude": 16384, + }, + } + + if got := loop.maxTokensForSession("chat-1"); got != 16384 { + t.Fatalf("expected provider max_tokens override, got %d", got) + } +} + +func TestMaxTokensForSessionFallsBackToAgentDefault(t *testing.T) { + loop := &AgentLoop{ + maxTokens: 4096, + providerNames: []string{"openai"}, + sessionProvider: map[string]string{}, + providerMaxTokens: map[string]int{}, + } + + if got := loop.maxTokensForSession("chat-1"); got != 4096 { + t.Fatalf("expected fallback to 
agent default max_tokens, got %d", got) + } +} + +func TestTemperatureForSessionUsesProviderOverride(t *testing.T) { + loop := &AgentLoop{ + temperature: 0.7, + providerNames: []string{"openai"}, + sessionProvider: map[string]string{ + "chat-1": "claude", + }, + providerTemperatures: map[string]float64{ + "claude": 0.15, + }, + } + + if got := loop.temperatureForSession("chat-1"); got != 0.15 { + t.Fatalf("expected provider temperature override, got %v", got) + } +} + +func TestTemperatureForSessionFallsBackToAgentDefault(t *testing.T) { + loop := &AgentLoop{ + temperature: 0.7, + providerNames: []string{"openai"}, + sessionProvider: map[string]string{}, + providerTemperatures: map[string]float64{}, + } + + if got := loop.temperatureForSession("chat-1"); got != 0.7 { + t.Fatalf("expected fallback to agent default temperature, got %v", got) + } +} + +func TestTryFallbackProvidersUsesFallbackProviderOptionsAndPersistsSelection(t *testing.T) { + fallback := &fallbackTestProvider{ + response: &providers.LLMResponse{Content: "fallback", FinishReason: "stop"}, + } + loop := &AgentLoop{ + maxTokens: 4096, + temperature: 0.7, + providerNames: []string{"openai", "claude"}, + sessionProvider: map[string]string{"chat-1": "openai"}, + providerPool: map[string]providers.LLMProvider{"claude": fallback}, + providerChain: []providerCandidate{{name: "openai", model: "gpt-a"}, {name: "claude", model: "claude-b"}}, + providerMaxTokens: map[string]int{"claude": 16384}, + providerTemperatures: map[string]float64{"claude": 0.15}, + providerResponses: map[string]config.ProviderResponsesConfig{ + "claude": {WebSearchEnabled: true}, + }, + } + + resp, providerName, err := loop.tryFallbackProviders(context.Background(), bus.InboundMessage{SessionKey: "chat-1"}, nil, nil, errors.New("primary failed")) + if err != nil { + t.Fatalf("expected fallback success, got %v", err) + } + if resp == nil || resp.Content != "fallback" { + t.Fatalf("unexpected fallback response: %#v", resp) + } + if 
providerName != "claude" { + t.Fatalf("expected provider claude, got %q", providerName) + } + if got := loop.getSessionProvider("chat-1"); got != "claude" { + t.Fatalf("expected session provider to switch to fallback provider, got %q", got) + } + if got := fallback.options["max_tokens"]; got != int64(16384) { + t.Fatalf("expected fallback max_tokens 16384, got %#v", got) + } + if got := fallback.options["temperature"]; got != 0.15 { + t.Fatalf("expected fallback temperature 0.15, got %#v", got) + } + if _, ok := fallback.options["responses_tools"]; !ok { + t.Fatalf("expected fallback responses_tools to be populated") + } +} + +func TestProcessMessageDoesNotPersistPartialAssistantToolHistoryOnFailure(t *testing.T) { + cfg := config.DefaultConfig() + cfg.Agents.Defaults.Workspace = t.TempDir() + cfg.Agents.Defaults.MaxToolIterations = 2 + + provider := &sequenceProvider{ + responses: []*providers.LLMResponse{ + { + Content: "", + ToolCalls: []providers.ToolCall{ + {ID: "tool-1", Name: "read_file", Arguments: map[string]interface{}{"path": "missing.txt"}}, + }, + FinishReason: "tool_calls", + }, + }, + errs: []error{nil, errors.New("second pass failed")}, + } + + loop := NewAgentLoop(cfg, bus.NewMessageBus(), provider, nil) + _, err := loop.processMessage(context.Background(), bus.InboundMessage{ + Channel: "cli", + ChatID: "direct", + SenderID: "user", + SessionKey: "cli:direct", + Content: "read file", + }) + if err == nil { + t.Fatalf("expected processMessage error") + } + + history := loop.sessions.GetHistory("cli:direct") + if len(history) != 1 { + t.Fatalf("expected only user message persisted on failure, got %d entries: %#v", len(history), history) + } + if history[0].Role != "user" || history[0].Content != "read file" { + t.Fatalf("unexpected persisted history: %#v", history) + } +} diff --git a/pkg/api/server.go b/pkg/api/server.go index 46aaa1c..fba097d 100644 --- a/pkg/api/server.go +++ b/pkg/api/server.go @@ -25,6 +25,7 @@ import ( "sync" "time" + 
"github.com/YspCoder/clawgo/pkg/bus" "github.com/YspCoder/clawgo/pkg/channels" cfgpkg "github.com/YspCoder/clawgo/pkg/config" "github.com/YspCoder/clawgo/pkg/providers" @@ -50,6 +51,7 @@ type Server struct { onConfigAfter func(forceRuntimeReload bool) error onCron func(action string, args map[string]interface{}) (interface{}, error) onToolsCatalog func() interface{} + messageBus *bus.MessageBus weixinChannel *channels.WeixinChannel oauthFlowMu sync.Mutex oauthFlows map[string]*providers.OAuthPendingFlow @@ -57,6 +59,15 @@ type Server struct { extraRoutes map[string]http.Handler eventSubsMu sync.Mutex eventSubs map[*websocket.Conn]struct{} + draftMu sync.RWMutex + channelDrafts channelDraftStore +} + +type channelDraftStore struct { + Weixin *cfgpkg.WeixinConfig + Telegram *cfgpkg.TelegramConfig + Feishu *cfgpkg.FeishuConfig + weixinRuntime *channels.WeixinChannel } func NewServer(host string, port int, token string) *Server { @@ -86,6 +97,7 @@ func (s *Server) SetChatHandler(fn func(ctx context.Context, sessionKey, content func (s *Server) SetChatHistoryHandler(fn func(sessionKey string) []map[string]interface{}) { s.onChatHistory = fn } +func (s *Server) SetMessageBus(mb *bus.MessageBus) { s.messageBus = mb } func (s *Server) SetConfigAfterHook(fn func(forceRuntimeReload bool) error) { s.onConfigAfter = fn } func (s *Server) SetCronHandler(fn func(action string, args map[string]interface{}) (interface{}, error)) { s.onCron = fn @@ -117,6 +129,356 @@ func (s *Server) SetWeixinChannel(ch *channels.WeixinChannel) { } } +func cloneWeixinConfig(cfg cfgpkg.WeixinConfig) cfgpkg.WeixinConfig { + cp := cfg + cp.AllowFrom = append([]string(nil), cfg.AllowFrom...) + cp.Accounts = append([]cfgpkg.WeixinAccountConfig(nil), cfg.Accounts...) + return cp +} + +func cloneTelegramConfig(cfg cfgpkg.TelegramConfig) cfgpkg.TelegramConfig { + cp := cfg + cp.AllowFrom = append([]string(nil), cfg.AllowFrom...) + cp.AllowChats = append([]string(nil), cfg.AllowChats...) 
+ return cp +} + +func cloneFeishuConfig(cfg cfgpkg.FeishuConfig) cfgpkg.FeishuConfig { + cp := cfg + cp.AllowFrom = append([]string(nil), cfg.AllowFrom...) + cp.AllowChats = append([]string(nil), cfg.AllowChats...) + return cp +} + +func validChannelDraftName(name string) bool { + switch strings.ToLower(strings.TrimSpace(name)) { + case "weixin", "telegram", "feishu": + return true + default: + return false + } +} + +func decodeMergedJSON[T any](current T, raw json.RawMessage) (T, error) { + out := current + if len(raw) == 0 || string(raw) == "null" { + return out, nil + } + baseBytes, err := json.Marshal(current) + if err != nil { + return out, err + } + merged := map[string]interface{}{} + if err := json.Unmarshal(baseBytes, &merged); err != nil { + return out, err + } + patch := map[string]interface{}{} + if err := json.Unmarshal(raw, &patch); err != nil { + return out, err + } + merged = mergeJSONMap(merged, patch) + mergedBytes, err := json.Marshal(merged) + if err != nil { + return out, err + } + if err := json.Unmarshal(mergedBytes, &out); err != nil { + return out, err + } + return out, nil +} + +func (s *Server) syncWeixinDraftLocked() { + if s.channelDrafts.Weixin == nil || s.channelDrafts.weixinRuntime == nil { + return + } + snapshot := s.channelDrafts.weixinRuntime.SnapshotConfig() + s.channelDrafts.Weixin = &snapshot +} + +func (s *Server) replaceWeixinDraftRuntimeLocked(cfg *cfgpkg.WeixinConfig) error { + if s.channelDrafts.weixinRuntime != nil { + _ = s.channelDrafts.weixinRuntime.Stop(context.Background()) + s.channelDrafts.weixinRuntime = nil + } + if cfg == nil || !cfg.Enabled { + return nil + } + if s.messageBus == nil { + return fmt.Errorf("message bus not configured") + } + ch, err := channels.NewWeixinChannel(cloneWeixinConfig(*cfg), s.messageBus) + if err != nil { + return err + } + if err := ch.Start(context.Background()); err != nil { + return err + } + s.channelDrafts.weixinRuntime = ch + return nil +} + +func (s *Server) 
clearChannelDraftsLocked() { + if s.channelDrafts.weixinRuntime != nil { + _ = s.channelDrafts.weixinRuntime.Stop(context.Background()) + } + s.channelDrafts = channelDraftStore{} +} + +func (s *Server) clearChannelDrafts() { + s.draftMu.Lock() + defer s.draftMu.Unlock() + s.clearChannelDraftsLocked() +} + +func (s *Server) effectiveWeixinRuntime(persisted cfgpkg.WeixinConfig) (cfgpkg.WeixinConfig, *channels.WeixinChannel, bool) { + s.draftMu.Lock() + defer s.draftMu.Unlock() + if s.channelDrafts.Weixin != nil { + s.syncWeixinDraftLocked() + effective := cloneWeixinConfig(*s.channelDrafts.Weixin) + return effective, s.channelDrafts.weixinRuntime, true + } + return cloneWeixinConfig(persisted), s.weixinChannel, false +} + +func (s *Server) currentChannelDraftPayload(cfg *cfgpkg.Config, channel string) map[string]interface{} { + channel = strings.ToLower(strings.TrimSpace(channel)) + payload := map[string]interface{}{ + "ok": true, + "channel": channel, + } + s.draftMu.Lock() + defer s.draftMu.Unlock() + switch channel { + case "weixin": + persisted := cloneWeixinConfig(cfg.Channels.Weixin) + var draft interface{} + effective := persisted + dirty := s.channelDrafts.Weixin != nil + if dirty { + s.syncWeixinDraftLocked() + effective = cloneWeixinConfig(*s.channelDrafts.Weixin) + draft = effective + } + payload["persisted"] = persisted + payload["draft"] = draft + payload["effective"] = effective + payload["dirty"] = dirty + payload["runtime_enabled"] = s.channelDrafts.weixinRuntime != nil && s.channelDrafts.weixinRuntime.IsRunning() + case "telegram": + persisted := cloneTelegramConfig(cfg.Channels.Telegram) + var draft interface{} + effective := persisted + dirty := s.channelDrafts.Telegram != nil + if dirty { + effective = cloneTelegramConfig(*s.channelDrafts.Telegram) + draft = effective + } + payload["persisted"] = persisted + payload["draft"] = draft + payload["effective"] = effective + payload["dirty"] = dirty + case "feishu": + persisted := 
cloneFeishuConfig(cfg.Channels.Feishu) + var draft interface{} + effective := persisted + dirty := s.channelDrafts.Feishu != nil + if dirty { + effective = cloneFeishuConfig(*s.channelDrafts.Feishu) + draft = effective + } + payload["persisted"] = persisted + payload["draft"] = draft + payload["effective"] = effective + payload["dirty"] = dirty + } + return payload +} + +func (s *Server) applyChannelDrafts(cfg *cfgpkg.Config) { + if cfg == nil { + return + } + s.draftMu.Lock() + defer s.draftMu.Unlock() + s.syncWeixinDraftLocked() + if s.channelDrafts.Weixin != nil { + cfg.Channels.Weixin = cloneWeixinConfig(*s.channelDrafts.Weixin) + } + if s.channelDrafts.Telegram != nil { + cfg.Channels.Telegram = cloneTelegramConfig(*s.channelDrafts.Telegram) + } + if s.channelDrafts.Feishu != nil { + cfg.Channels.Feishu = cloneFeishuConfig(*s.channelDrafts.Feishu) + } +} + +func (s *Server) handleWebUIChannelDraft(w http.ResponseWriter, r *http.Request) { + if !s.checkAuth(r) { + http.Error(w, "unauthorized", http.StatusUnauthorized) + return + } + cfg, err := s.loadConfig() + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + switch r.Method { + case http.MethodGet: + channel := strings.ToLower(strings.TrimSpace(r.URL.Query().Get("channel"))) + if channel == "" { + writeJSON(w, map[string]interface{}{ + "ok": true, + "channels": map[string]interface{}{ + "weixin": s.currentChannelDraftPayload(cfg, "weixin"), + "telegram": s.currentChannelDraftPayload(cfg, "telegram"), + "feishu": s.currentChannelDraftPayload(cfg, "feishu"), + }, + }) + return + } + if !validChannelDraftName(channel) { + http.Error(w, "unsupported channel", http.StatusBadRequest) + return + } + writeJSON(w, s.currentChannelDraftPayload(cfg, channel)) + case http.MethodPost: + var body struct { + Channel string `json:"channel"` + Config json.RawMessage `json:"config"` + } + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { + http.Error(w, "invalid json", 
http.StatusBadRequest) + return + } + channel := strings.ToLower(strings.TrimSpace(body.Channel)) + if !validChannelDraftName(channel) { + http.Error(w, "unsupported channel", http.StatusBadRequest) + return + } + s.draftMu.Lock() + switch channel { + case "weixin": + current := cfg.Channels.Weixin + if s.channelDrafts.Weixin != nil { + s.syncWeixinDraftLocked() + current = cloneWeixinConfig(*s.channelDrafts.Weixin) + } + next, err := decodeMergedJSON(current, body.Config) + if err != nil { + s.draftMu.Unlock() + http.Error(w, "invalid weixin config", http.StatusBadRequest) + return + } + next = cloneWeixinConfig(next) + if err := s.replaceWeixinDraftRuntimeLocked(&next); err != nil { + s.draftMu.Unlock() + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + s.channelDrafts.Weixin = &next + case "telegram": + current := cfg.Channels.Telegram + if s.channelDrafts.Telegram != nil { + current = cloneTelegramConfig(*s.channelDrafts.Telegram) + } + next, err := decodeMergedJSON(current, body.Config) + if err != nil { + s.draftMu.Unlock() + http.Error(w, "invalid telegram config", http.StatusBadRequest) + return + } + next = cloneTelegramConfig(next) + s.channelDrafts.Telegram = &next + case "feishu": + current := cfg.Channels.Feishu + if s.channelDrafts.Feishu != nil { + current = cloneFeishuConfig(*s.channelDrafts.Feishu) + } + next, err := decodeMergedJSON(current, body.Config) + if err != nil { + s.draftMu.Unlock() + http.Error(w, "invalid feishu config", http.StatusBadRequest) + return + } + next = cloneFeishuConfig(next) + s.channelDrafts.Feishu = &next + } + s.draftMu.Unlock() + s.broadcastEvent(map[string]interface{}{ + "type": "channel_draft_changed", + "channel": channel, + }) + writeJSON(w, s.currentChannelDraftPayload(cfg, channel)) + case http.MethodDelete: + channel := strings.ToLower(strings.TrimSpace(r.URL.Query().Get("channel"))) + s.draftMu.Lock() + if channel == "" { + s.clearChannelDraftsLocked() + s.draftMu.Unlock() + writeJSON(w, 
map[string]interface{}{"ok": true, "cleared": "all"}) + return + } + if !validChannelDraftName(channel) { + s.draftMu.Unlock() + http.Error(w, "unsupported channel", http.StatusBadRequest) + return + } + switch channel { + case "weixin": + if s.channelDrafts.weixinRuntime != nil { + _ = s.channelDrafts.weixinRuntime.Stop(context.Background()) + s.channelDrafts.weixinRuntime = nil + } + s.channelDrafts.Weixin = nil + case "telegram": + s.channelDrafts.Telegram = nil + case "feishu": + s.channelDrafts.Feishu = nil + } + s.draftMu.Unlock() + s.broadcastEvent(map[string]interface{}{ + "type": "channel_draft_changed", + "channel": channel, + }) + writeJSON(w, map[string]interface{}{"ok": true, "channel": channel, "cleared": true}) + default: + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + } +} + +func (s *Server) handleWebUIChannelDraftCommit(w http.ResponseWriter, r *http.Request) { + if !s.checkAuth(r) { + http.Error(w, "unauthorized", http.StatusUnauthorized) + return + } + if r.Method != http.MethodPost { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + cfg, err := s.loadConfig() + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + s.applyChannelDrafts(cfg) + if err := s.persistWebUIConfig(cfg); err != nil { + var validationErr *configValidationError + if errors.As(err, &validationErr) { + writeJSONStatus(w, http.StatusBadRequest, map[string]interface{}{ + "ok": false, + "error": validationErr.Error(), + "errors": validationErr.Fields, + }) + return + } + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + s.clearChannelDrafts() + writeJSON(w, map[string]interface{}{"ok": true, "committed": true}) +} + func (s *Server) handleWebUIEventsLive(w http.ResponseWriter, r *http.Request) { if !s.checkAuth(r) { http.Error(w, "unauthorized", http.StatusUnauthorized) @@ -225,11 +587,14 @@ func (s *Server) Start(ctx context.Context) error { mux.HandleFunc("/api/sessions", 
s.handleWebUISessions) mux.HandleFunc("/api/memory", s.handleWebUIMemory) mux.HandleFunc("/api/workspace_file", s.handleWebUIWorkspaceFile) + mux.HandleFunc("/api/workspace_docs", s.handleWebUIWorkspaceDocs) mux.HandleFunc("/api/tool_allowlist_groups", s.handleWebUIToolAllowlistGroups) mux.HandleFunc("/api/tools", s.handleWebUITools) mux.HandleFunc("/api/mcp/install", s.handleWebUIMCPInstall) mux.HandleFunc("/api/logs/live", s.handleWebUILogsLive) mux.HandleFunc("/api/logs/recent", s.handleWebUILogsRecent) + mux.HandleFunc("/api/channels/draft", s.handleWebUIChannelDraft) + mux.HandleFunc("/api/channels/draft/commit", s.handleWebUIChannelDraftCommit) s.extraRoutesMu.RLock() for path, handler := range s.extraRoutes { routePath := path @@ -1227,11 +1592,17 @@ func (s *Server) handleWebUIWeixinLoginStart(w http.ResponseWriter, r *http.Requ http.Error(w, "method not allowed", http.StatusMethodNotAllowed) return } - if s.weixinChannel == nil { + cfg, err := s.loadConfig() + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + _, ch, _ := s.effectiveWeixinRuntime(cfg.Channels.Weixin) + if ch == nil { http.Error(w, "weixin channel unavailable", http.StatusServiceUnavailable) return } - if _, err := s.weixinChannel.StartLogin(r.Context()); err != nil { + if _, err := ch.StartLogin(r.Context()); err != nil { http.Error(w, err.Error(), http.StatusBadGateway) return } @@ -1248,7 +1619,13 @@ func (s *Server) handleWebUIWeixinLoginCancel(w http.ResponseWriter, r *http.Req http.Error(w, "method not allowed", http.StatusMethodNotAllowed) return } - if s.weixinChannel == nil { + cfg, err := s.loadConfig() + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + _, ch, _ := s.effectiveWeixinRuntime(cfg.Channels.Weixin) + if ch == nil { http.Error(w, "weixin channel unavailable", http.StatusServiceUnavailable) return } @@ -1259,7 +1636,7 @@ func (s *Server) handleWebUIWeixinLoginCancel(w http.ResponseWriter, 
r *http.Req http.Error(w, "invalid json body", http.StatusBadRequest) return } - if !s.weixinChannel.CancelPendingLogin(body.LoginID) { + if !ch.CancelPendingLogin(body.LoginID) { http.Error(w, "login_id not found", http.StatusNotFound) return } @@ -1283,8 +1660,14 @@ func (s *Server) handleWebUIWeixinQR(w http.ResponseWriter, r *http.Request) { } qrCode := "" loginID := strings.TrimSpace(r.URL.Query().Get("login_id")) - if loginID != "" && s.weixinChannel != nil { - if pending := s.weixinChannel.PendingLoginByID(loginID); pending != nil { + cfg, err := s.loadConfig() + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + _, ch, _ := s.effectiveWeixinRuntime(cfg.Channels.Weixin) + if loginID != "" && ch != nil { + if pending := ch.PendingLoginByID(loginID); pending != nil { qrCode = fallbackString(pending.QRCodeImgContent, pending.QRCode) } } @@ -1318,7 +1701,13 @@ func (s *Server) handleWebUIWeixinAccountRemove(w http.ResponseWriter, r *http.R http.Error(w, "method not allowed", http.StatusMethodNotAllowed) return } - if s.weixinChannel == nil { + cfg, err := s.loadConfig() + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + _, ch, _ := s.effectiveWeixinRuntime(cfg.Channels.Weixin) + if ch == nil { http.Error(w, "weixin channel unavailable", http.StatusServiceUnavailable) return } @@ -1329,7 +1718,7 @@ func (s *Server) handleWebUIWeixinAccountRemove(w http.ResponseWriter, r *http.R http.Error(w, "invalid json body", http.StatusBadRequest) return } - if err := s.weixinChannel.RemoveAccount(body.BotID); err != nil { + if err := ch.RemoveAccount(body.BotID); err != nil { http.Error(w, err.Error(), http.StatusBadRequest) return } @@ -1346,7 +1735,13 @@ func (s *Server) handleWebUIWeixinAccountDefault(w http.ResponseWriter, r *http. 
http.Error(w, "method not allowed", http.StatusMethodNotAllowed) return } - if s.weixinChannel == nil { + cfg, err := s.loadConfig() + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + _, ch, _ := s.effectiveWeixinRuntime(cfg.Channels.Weixin) + if ch == nil { http.Error(w, "weixin channel unavailable", http.StatusServiceUnavailable) return } @@ -1357,7 +1752,7 @@ func (s *Server) handleWebUIWeixinAccountDefault(w http.ResponseWriter, r *http. http.Error(w, "invalid json body", http.StatusBadRequest) return } - if err := s.weixinChannel.SetDefaultAccount(body.BotID); err != nil { + if err := ch.SetDefaultAccount(body.BotID); err != nil { http.Error(w, err.Error(), http.StatusBadRequest) return } @@ -1373,25 +1768,35 @@ func (s *Server) webUIWeixinStatusPayload(ctx context.Context) (map[string]inter "error": err.Error(), }, http.StatusInternalServerError } - weixinCfg := cfg.Channels.Weixin - if s.weixinChannel == nil { + persistedCfg := cloneWeixinConfig(cfg.Channels.Weixin) + weixinCfg, ch, usingDraft := s.effectiveWeixinRuntime(persistedCfg) + if ch == nil { return map[string]interface{}{ - "ok": false, - "enabled": weixinCfg.Enabled, - "base_url": weixinCfg.BaseURL, - "error": "weixin channel unavailable", + "ok": false, + "enabled": weixinCfg.Enabled, + "config_enabled": persistedCfg.Enabled, + "runtime_enabled": false, + "draft_dirty": usingDraft, + "base_url": weixinCfg.BaseURL, + "error": "weixin channel unavailable", }, http.StatusOK } - pendingLogins, err := s.weixinChannel.RefreshLoginStatuses(ctx) + pendingLogins, err := ch.RefreshLoginStatuses(ctx) if err != nil { return map[string]interface{}{ - "ok": false, - "enabled": weixinCfg.Enabled, - "base_url": weixinCfg.BaseURL, - "error": err.Error(), + "ok": false, + "enabled": weixinCfg.Enabled, + "config_enabled": persistedCfg.Enabled, + "runtime_enabled": ch.IsRunning(), + "draft_dirty": usingDraft, + "base_url": weixinCfg.BaseURL, + "error": err.Error(), }, 
http.StatusOK } - accounts := s.weixinChannel.ListAccounts() + if usingDraft { + weixinCfg = ch.SnapshotConfig() + } + accounts := ch.ListAccounts() pendingPayload := make([]map[string]interface{}, 0, len(pendingLogins)) for _, pending := range pendingLogins { pendingPayload = append(pendingPayload, map[string]interface{}{ @@ -1409,10 +1814,13 @@ func (s *Server) webUIWeixinStatusPayload(ctx context.Context) (map[string]inter firstPending = pendingLogins[0] } return map[string]interface{}{ - "ok": true, - "enabled": weixinCfg.Enabled, - "base_url": fallbackString(weixinCfg.BaseURL, "https://ilinkai.weixin.qq.com"), - "pending_logins": pendingPayload, + "ok": true, + "enabled": weixinCfg.Enabled, + "config_enabled": persistedCfg.Enabled, + "runtime_enabled": ch.IsRunning(), + "draft_dirty": usingDraft, + "base_url": fallbackString(weixinCfg.BaseURL, "https://ilinkai.weixin.qq.com"), + "pending_logins": pendingPayload, "pending_login": map[string]interface{}{ "login_id": pendingString(firstPending, "login_id"), "qr_code": pendingString(firstPending, "qr_code"), @@ -2976,9 +3384,6 @@ func (s *Server) handleWebUIMemory(w http.ResponseWriter, r *http.Request) { path := strings.TrimSpace(r.URL.Query().Get("path")) if path == "" { files := make([]string, 0, 16) - if _, err := os.Stat(filepath.Join(s.workspacePath, "MEMORY.md")); err == nil { - files = append(files, "MEMORY.md") - } entries, err := os.ReadDir(memoryDir) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) @@ -2993,11 +3398,7 @@ func (s *Server) handleWebUIMemory(w http.ResponseWriter, r *http.Request) { writeJSON(w, map[string]interface{}{"ok": true, "files": files}) return } - baseDir := memoryDir - if strings.EqualFold(path, "MEMORY.md") { - baseDir = strings.TrimSpace(s.workspacePath) - } - clean, content, found, err := readRelativeTextFile(baseDir, path) + clean, content, found, err := readRelativeTextFile(memoryDir, path) if err != nil { http.Error(w, err.Error(), 
http.StatusBadRequest) return @@ -3073,6 +3474,70 @@ func (s *Server) handleWebUIWorkspaceFile(w http.ResponseWriter, r *http.Request } } +var workspaceDocFiles = []string{ + "AGENTS.md", + "BOOT.md", + "BOOTSTRAP.md", + "HEARTBEAT.md", + "IDENTITY.md", + "MEMORY.md", + "SOUL.md", + "TOOLS.md", + "USER.md", +} + +func (s *Server) handleWebUIWorkspaceDocs(w http.ResponseWriter, r *http.Request) { + if !s.checkAuth(r) { + http.Error(w, "unauthorized", http.StatusUnauthorized) + return + } + if r.Method != http.MethodGet { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + workspace := strings.TrimSpace(s.workspacePath) + path := strings.TrimSpace(r.URL.Query().Get("path")) + if path != "" { + if !isWorkspaceDocAllowed(path) { + http.Error(w, "invalid path", http.StatusBadRequest) + return + } + clean, content, found, err := readRelativeTextFile(workspace, path) + if err != nil { + http.Error(w, err.Error(), relativeFilePathStatus(err)) + return + } + if !found { + http.Error(w, os.ErrNotExist.Error(), http.StatusInternalServerError) + return + } + writeJSON(w, map[string]interface{}{"ok": true, "path": clean, "content": content}) + return + } + files := make([]string, 0, len(workspaceDocFiles)) + for _, name := range workspaceDocFiles { + _, _, found, err := readRelativeTextFile(workspace, name) + if err != nil { + http.Error(w, err.Error(), relativeFilePathStatus(err)) + return + } + if !found { + continue + } + files = append(files, name) + } + writeJSON(w, map[string]interface{}{"ok": true, "files": files}) +} + +func isWorkspaceDocAllowed(name string) bool { + for _, allowed := range workspaceDocFiles { + if name == allowed { + return true + } + } + return false +} + func (s *Server) handleWebUILogsRecent(w http.ResponseWriter, r *http.Request) { if !s.checkAuth(r) { http.Error(w, "unauthorized", http.StatusUnauthorized) diff --git a/pkg/api/server_test.go b/pkg/api/server_test.go index e8191e0..ec7642e 100644 --- 
a/pkg/api/server_test.go +++ b/pkg/api/server_test.go @@ -12,6 +12,7 @@ import ( "testing" "time" + "github.com/YspCoder/clawgo/pkg/bus" cfgpkg "github.com/YspCoder/clawgo/pkg/config" "github.com/gorilla/websocket" ) @@ -139,6 +140,100 @@ func TestHandleWebUIConfigPostSavesNormalizedConfig(t *testing.T) { } } +func TestHandleWebUIChannelDraftCommitPersistsDrafts(t *testing.T) { + t.Parallel() + + tmp := t.TempDir() + cfgPath := filepath.Join(tmp, "config.json") + cfg := cfgpkg.DefaultConfig() + if err := cfgpkg.SaveConfig(cfgPath, cfg); err != nil { + t.Fatalf("save config: %v", err) + } + + srv := NewServer("127.0.0.1", 0, "") + srv.SetConfigPath(cfgPath) + srv.SetMessageBus(bus.NewMessageBus()) + hookCalled := 0 + srv.SetConfigAfterHook(func(forceRuntimeReload bool) error { + hookCalled++ + return nil + }) + + draftReq := httptest.NewRequest(http.MethodPost, "/api/channels/draft", strings.NewReader(`{"channel":"telegram","config":{"enabled":true,"token":"bot-token","streaming":true}}`)) + draftReq.Header.Set("Content-Type", "application/json") + draftRec := httptest.NewRecorder() + srv.handleWebUIChannelDraft(draftRec, draftReq) + if draftRec.Code != http.StatusOK { + t.Fatalf("expected 200 from draft save, got %d: %s", draftRec.Code, draftRec.Body.String()) + } + + commitReq := httptest.NewRequest(http.MethodPost, "/api/channels/draft/commit", nil) + commitRec := httptest.NewRecorder() + srv.handleWebUIChannelDraftCommit(commitRec, commitReq) + if commitRec.Code != http.StatusOK { + t.Fatalf("expected 200 from draft commit, got %d: %s", commitRec.Code, commitRec.Body.String()) + } + if hookCalled != 1 { + t.Fatalf("expected reload hook once, got %d", hookCalled) + } + + updated, err := cfgpkg.LoadConfig(cfgPath) + if err != nil { + t.Fatalf("reload config: %v", err) + } + if !updated.Channels.Telegram.Enabled { + t.Fatalf("expected telegram enabled after draft commit") + } + if updated.Channels.Telegram.Token != "bot-token" { + t.Fatalf("expected telegram token 
to persist, got %q", updated.Channels.Telegram.Token) + } +} + +func TestHandleWebUIWeixinStatusReflectsDraftRuntime(t *testing.T) { + t.Parallel() + + tmp := t.TempDir() + cfgPath := filepath.Join(tmp, "config.json") + cfg := cfgpkg.DefaultConfig() + cfg.Channels.Weixin.Enabled = false + if err := cfgpkg.SaveConfig(cfgPath, cfg); err != nil { + t.Fatalf("save config: %v", err) + } + + srv := NewServer("127.0.0.1", 0, "") + srv.SetConfigPath(cfgPath) + srv.SetMessageBus(bus.NewMessageBus()) + + draftReq := httptest.NewRequest(http.MethodPost, "/api/channels/draft", strings.NewReader(`{"channel":"weixin","config":{"enabled":true,"base_url":"https://ilinkai.weixin.qq.com"}}`)) + draftReq.Header.Set("Content-Type", "application/json") + draftRec := httptest.NewRecorder() + srv.handleWebUIChannelDraft(draftRec, draftReq) + if draftRec.Code != http.StatusOK { + t.Fatalf("expected 200 from weixin draft save, got %d: %s", draftRec.Code, draftRec.Body.String()) + } + + statusReq := httptest.NewRequest(http.MethodGet, "/api/weixin/status", nil) + statusRec := httptest.NewRecorder() + srv.handleWebUIWeixinStatus(statusRec, statusReq) + if statusRec.Code != http.StatusOK { + t.Fatalf("expected 200 from status, got %d: %s", statusRec.Code, statusRec.Body.String()) + } + + var payload map[string]interface{} + if err := json.Unmarshal(statusRec.Body.Bytes(), &payload); err != nil { + t.Fatalf("decode status: %v", err) + } + if payload["draft_dirty"] != true { + t.Fatalf("expected draft_dirty=true, got %#v", payload["draft_dirty"]) + } + if payload["config_enabled"] != false { + t.Fatalf("expected config_enabled=false, got %#v", payload["config_enabled"]) + } + if payload["runtime_enabled"] != true { + t.Fatalf("expected runtime_enabled=true, got %#v", payload["runtime_enabled"]) + } +} + func TestWithCORSEchoesPreflightHeaders(t *testing.T) { t.Parallel() @@ -255,13 +350,10 @@ func TestSaveProviderConfigForcesRuntimeReload(t *testing.T) { } } -func 
TestHandleWebUIMemoryListsAndReadsWorkspaceMemoryFile(t *testing.T) { +func TestHandleWebUIMemoryListsAndReadsMemoryDirFile(t *testing.T) { t.Parallel() tmp := t.TempDir() - if err := os.WriteFile(filepath.Join(tmp, "MEMORY.md"), []byte("# long-term\n"), 0o644); err != nil { - t.Fatalf("write workspace memory: %v", err) - } if err := os.MkdirAll(filepath.Join(tmp, "memory"), 0o755); err != nil { t.Fatalf("mkdir memory dir: %v", err) } @@ -285,11 +377,11 @@ func TestHandleWebUIMemoryListsAndReadsWorkspaceMemoryFile(t *testing.T) { if err := json.Unmarshal(listRec.Body.Bytes(), &listPayload); err != nil { t.Fatalf("decode list payload: %v", err) } - if len(listPayload.Files) < 2 || listPayload.Files[0] != "MEMORY.md" { - t.Fatalf("expected MEMORY.md in memory file list, got %+v", listPayload.Files) + if len(listPayload.Files) != 1 || listPayload.Files[0] != "2026-03-19.md" { + t.Fatalf("expected only memory dir files, got %+v", listPayload.Files) } - readReq := httptest.NewRequest(http.MethodGet, "/api/memory?path=MEMORY.md", nil) + readReq := httptest.NewRequest(http.MethodGet, "/api/memory?path=2026-03-19.md", nil) readRec := httptest.NewRecorder() srv.handleWebUIMemory(readRec, readReq) if readRec.Code != http.StatusOK { @@ -303,11 +395,71 @@ func TestHandleWebUIMemoryListsAndReadsWorkspaceMemoryFile(t *testing.T) { if err := json.Unmarshal(readRec.Body.Bytes(), &readPayload); err != nil { t.Fatalf("decode read payload: %v", err) } - if readPayload.Path != "MEMORY.md" || readPayload.Content != "# long-term\n" { + if readPayload.Path != "2026-03-19.md" || readPayload.Content != "daily\n" { t.Fatalf("unexpected memory payload: %+v", readPayload) } } +func TestHandleWebUIWorkspaceDocsListAndRead(t *testing.T) { + t.Parallel() + + tmp := t.TempDir() + if err := os.WriteFile(filepath.Join(tmp, "AGENTS.md"), []byte("agents\n"), 0o644); err != nil { + t.Fatalf("write AGENTS.md: %v", err) + } + if err := os.WriteFile(filepath.Join(tmp, "MEMORY.md"), []byte("memory\n"), 
0o644); err != nil { + t.Fatalf("write MEMORY.md: %v", err) + } + + srv := NewServer("127.0.0.1", 0, "") + srv.SetWorkspacePath(tmp) + + req := httptest.NewRequest(http.MethodGet, "/api/workspace_docs", nil) + rec := httptest.NewRecorder() + srv.handleWebUIWorkspaceDocs(rec, req) + if rec.Code != http.StatusOK { + t.Fatalf("expected 200, got %d: %s", rec.Code, rec.Body.String()) + } + + var payload struct { + OK bool `json:"ok"` + Files []string `json:"files"` + } + if err := json.Unmarshal(rec.Body.Bytes(), &payload); err != nil { + t.Fatalf("decode payload: %v", err) + } + if !payload.OK { + t.Fatalf("expected ok=true, got %+v", payload) + } + if len(payload.Files) != 2 { + t.Fatalf("expected 2 existing docs, got %+v", payload.Files) + } + if payload.Files[0] != "AGENTS.md" { + t.Fatalf("unexpected first doc payload: %+v", payload.Files[0]) + } + if payload.Files[1] != "MEMORY.md" { + t.Fatalf("unexpected second doc payload: %+v", payload.Files[1]) + } + + readReq := httptest.NewRequest(http.MethodGet, "/api/workspace_docs?path=AGENTS.md", nil) + readRec := httptest.NewRecorder() + srv.handleWebUIWorkspaceDocs(readRec, readReq) + if readRec.Code != http.StatusOK { + t.Fatalf("expected 200, got %d: %s", readRec.Code, readRec.Body.String()) + } + var readPayload struct { + OK bool `json:"ok"` + Path string `json:"path"` + Content string `json:"content"` + } + if err := json.Unmarshal(readRec.Body.Bytes(), &readPayload); err != nil { + t.Fatalf("decode read payload: %v", err) + } + if readPayload.Path != "AGENTS.md" || readPayload.Content != "agents\n" { + t.Fatalf("unexpected read payload: %+v", readPayload) + } +} + func TestHandleWebUIChatLive(t *testing.T) { t.Parallel() diff --git a/pkg/channels/weixin.go b/pkg/channels/weixin.go index c64aa90..0fba49b 100644 --- a/pkg/channels/weixin.go +++ b/pkg/channels/weixin.go @@ -389,6 +389,22 @@ func (c *WeixinChannel) ListAccounts() []WeixinAccountSnapshot { return out } +func (c *WeixinChannel) SnapshotConfig() 
config.WeixinConfig { + c.mu.RLock() + defer c.mu.RUnlock() + + cfgCopy := c.config + cfgCopy.AllowFrom = append([]string(nil), cfgCopy.AllowFrom...) + cfgCopy.Accounts = append([]config.WeixinAccountConfig(nil), c.accountConfigsLocked()...) + cfgCopy.DefaultBotID = strings.TrimSpace(c.defaultBotIDLocked()) + cfgCopy.BotID = "" + cfgCopy.BotToken = "" + cfgCopy.IlinkUserID = "" + cfgCopy.ContextToken = "" + cfgCopy.GetUpdatesBuf = "" + return cfgCopy +} + func (c *WeixinChannel) SetDefaultAccount(botID string) error { botID = strings.TrimSpace(botID) if botID == "" { diff --git a/pkg/config/config.go b/pkg/config/config.go index 7267ddb..cd729af 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -226,6 +226,8 @@ type ProviderConfig struct { APIKey string `json:"api_key" env:"CLAWGO_PROVIDERS_{{.Name}}_API_KEY"` APIBase string `json:"api_base" env:"CLAWGO_PROVIDERS_{{.Name}}_API_BASE"` Models []string `json:"models" env:"CLAWGO_PROVIDERS_{{.Name}}_MODELS"` + MaxTokens int `json:"max_tokens,omitempty"` + Temperature float64 `json:"temperature,omitempty"` SupportsResponsesCompact bool `json:"supports_responses_compact" env:"CLAWGO_PROVIDERS_{{.Name}}_SUPPORTS_RESPONSES_COMPACT"` Auth string `json:"auth" env:"CLAWGO_PROVIDERS_{{.Name}}_AUTH"` TimeoutSec int `json:"timeout_sec" env:"CLAWGO_PROVIDERS_PROXY_TIMEOUT_SEC"` diff --git a/pkg/config/normalized.go b/pkg/config/normalized.go index 9871422..20fe9b3 100644 --- a/pkg/config/normalized.go +++ b/pkg/config/normalized.go @@ -53,6 +53,8 @@ type NormalizedRuntimeRouterConfig struct { type NormalizedRuntimeProviderConfig struct { Auth string `json:"auth,omitempty"` APIBase string `json:"api_base,omitempty"` + MaxTokens int `json:"max_tokens,omitempty"` + Temperature float64 `json:"temperature,omitempty"` TimeoutSec int `json:"timeout_sec,omitempty"` OAuth ProviderOAuthConfig `json:"oauth,omitempty"` RuntimePersist bool `json:"runtime_persist,omitempty"` @@ -143,6 +145,8 @@ func (c *Config) NormalizedView() 
NormalizedConfig { view.Runtime.Providers[name] = NormalizedRuntimeProviderConfig{ Auth: pc.Auth, APIBase: pc.APIBase, + MaxTokens: pc.MaxTokens, + Temperature: pc.Temperature, TimeoutSec: pc.TimeoutSec, OAuth: pc.OAuth, RuntimePersist: pc.RuntimePersist, @@ -232,6 +236,12 @@ func (c *Config) ApplyNormalizedView(view NormalizedConfig) { current := c.Models.Providers[name] current.Auth = strings.TrimSpace(item.Auth) current.APIBase = strings.TrimSpace(item.APIBase) + if item.MaxTokens > 0 { + current.MaxTokens = item.MaxTokens + } else if item.MaxTokens == 0 { + current.MaxTokens = 0 + } + current.Temperature = item.Temperature if item.TimeoutSec > 0 { current.TimeoutSec = item.TimeoutSec } diff --git a/pkg/config/normalized_test.go b/pkg/config/normalized_test.go index 31e26a7..520e08b 100644 --- a/pkg/config/normalized_test.go +++ b/pkg/config/normalized_test.go @@ -5,6 +5,13 @@ import "testing" func TestNormalizedViewProjectsCoreAndRuntime(t *testing.T) { cfg := DefaultConfig() cfg.Agents.Router.Enabled = true + cfg.Models.Providers["openai"] = ProviderConfig{ + APIBase: "https://api.openai.com/v1", + Models: []string{"gpt-5.4"}, + MaxTokens: 12288, + Temperature: 0.35, + TimeoutSec: 90, + } cfg.Agents.Subagents["coder"] = SubagentConfig{ Enabled: true, Role: "coding", @@ -27,4 +34,10 @@ func TestNormalizedViewProjectsCoreAndRuntime(t *testing.T) { if !view.Runtime.Router.Enabled || view.Runtime.Router.Strategy != "rules_first" { t.Fatalf("unexpected runtime router: %+v", view.Runtime.Router) } + if got := view.Runtime.Providers["openai"].MaxTokens; got != 12288 { + t.Fatalf("expected provider max_tokens in normalized runtime view, got %d", got) + } + if got := view.Runtime.Providers["openai"].Temperature; got != 0.35 { + t.Fatalf("expected provider temperature in normalized runtime view, got %v", got) + } } diff --git a/pkg/config/validate.go b/pkg/config/validate.go index 529824b..1963e8d 100644 --- a/pkg/config/validate.go +++ b/pkg/config/validate.go @@ 
-478,6 +478,9 @@ func validateProviderConfig(path string, p ProviderConfig) []error { if p.TimeoutSec <= 0 { errs = append(errs, fmt.Errorf("%s.timeout_sec must be > 0", path)) } + if p.MaxTokens < 0 { + errs = append(errs, fmt.Errorf("%s.max_tokens must be >= 0", path)) + } switch authMode { case "", "bearer", "oauth", "none", "hybrid": default: diff --git a/pkg/providers/codex_provider.go b/pkg/providers/codex_provider.go index 054c1ef..2a8c4e5 100644 --- a/pkg/providers/codex_provider.go +++ b/pkg/providers/codex_provider.go @@ -495,9 +495,7 @@ func (p *CodexProvider) doStreamAttempt(req *http.Request, attempt authAttempt, if typ := strings.TrimSpace(fmt.Sprintf("%v", obj["type"])); typ == "response.completed" { completed = true if respObj, ok := obj["response"]; ok { - if b, err := json.Marshal(respObj); err == nil { - finalJSON = b - } + finalJSON = mergeStreamFinalJSON(finalJSON, respObj) } } } diff --git a/pkg/providers/codex_provider_test.go b/pkg/providers/codex_provider_test.go index 306c591..43e301d 100644 --- a/pkg/providers/codex_provider_test.go +++ b/pkg/providers/codex_provider_test.go @@ -214,6 +214,27 @@ func TestCodexProviderChatFallsBackToHTTPStreamResponse(t *testing.T) { } } +func TestCodexProviderChatMergesLateUsageFromStreamingCompletion(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + _, _ = fmt.Fprint(w, "data: {\"type\":\"response.completed\",\"response\":{\"status\":\"completed\",\"output_text\":\"hello\"}}\n\n") + _, _ = fmt.Fprint(w, "data: {\"type\":\"response.completed\",\"response\":{\"usage\":{\"input_tokens\":1,\"output_tokens\":2,\"total_tokens\":3}}}\n\n") + })) + defer server.Close() + + provider := NewCodexProvider("codex", "test-api-key", server.URL, "gpt-5.4", false, "", 5*time.Second, nil) + resp, err := provider.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "gpt-5.4", nil) + if err != nil 
{ + t.Fatalf("Chat error: %v", err) + } + if resp.Content != "hello" { + t.Fatalf("unexpected response content: %q", resp.Content) + } + if resp.Usage == nil || resp.Usage.PromptTokens != 1 || resp.Usage.CompletionTokens != 2 || resp.Usage.TotalTokens != 3 { + t.Fatalf("unexpected usage: %#v", resp.Usage) + } +} + func TestCodexHandleAttemptFailureMarksAPIKeyCooldown(t *testing.T) { provider := NewCodexProvider("codex-websocket-failure", "test-api-key", "", "gpt-5.4", false, "", 5*time.Second, nil) provider.handleAttemptFailure(authAttempt{kind: "api_key", token: "test-api-key"}, http.StatusTooManyRequests, []byte(`{"error":{"message":"rate limit exceeded"}}`)) diff --git a/pkg/providers/http_provider.go b/pkg/providers/http_provider.go index e8620b5..7da55ff 100644 --- a/pkg/providers/http_provider.go +++ b/pkg/providers/http_provider.go @@ -1022,15 +1022,13 @@ func (p *HTTPProvider) doStreamAttempt(req *http.Request, attempt authAttempt, o if err := json.Unmarshal([]byte(payload), &obj); err == nil { if typ := strings.TrimSpace(fmt.Sprintf("%v", obj["type"])); typ == "response.completed" { if respObj, ok := obj["response"]; ok { - if b, err := json.Marshal(respObj); err == nil { - finalJSON = b - } + finalJSON = mergeStreamFinalJSON(finalJSON, respObj) } } if choices, ok := obj["choices"]; ok { - if b, err := json.Marshal(map[string]interface{}{"choices": choices, "usage": obj["usage"]}); err == nil { - finalJSON = b - } + finalJSON = mergeStreamFinalJSON(finalJSON, map[string]interface{}{"choices": choices, "usage": obj["usage"]}) + } else if _, ok := obj["usage"]; ok && len(finalJSON) > 0 { + finalJSON = mergeStreamFinalJSON(finalJSON, map[string]interface{}{"usage": obj["usage"]}) } } } @@ -1049,6 +1047,56 @@ func (p *HTTPProvider) doStreamAttempt(req *http.Request, attempt authAttempt, o return finalJSON, resp.StatusCode, ctype, false, nil } +func mergeStreamFinalJSON(existing []byte, incoming interface{}) []byte { + if incoming == nil { + return existing + } 
+ incomingMap, ok := incoming.(map[string]interface{}) + if !ok { + data, err := json.Marshal(incoming) + if err != nil { + return existing + } + return data + } + if len(existing) == 0 { + data, err := json.Marshal(incomingMap) + if err != nil { + return existing + } + return data + } + var merged map[string]interface{} + if err := json.Unmarshal(existing, &merged); err != nil || merged == nil { + merged = map[string]interface{}{} + } + merged = mergeStringAnyMaps(merged, incomingMap) + data, err := json.Marshal(merged) + if err != nil { + return existing + } + return data +} + +func mergeStringAnyMaps(dst, src map[string]interface{}) map[string]interface{} { + if dst == nil { + dst = map[string]interface{}{} + } + for key, value := range src { + if value == nil { + continue + } + if nestedSrc, ok := value.(map[string]interface{}); ok { + if nestedDst, ok := dst[key].(map[string]interface{}); ok { + dst[key] = mergeStringAnyMaps(nestedDst, nestedSrc) + continue + } + } + dst[key] = value + } + return dst +} + func shouldRetryOAuthQuota(status int, body []byte) bool { _, retry := classifyOAuthFailure(status, body) return retry diff --git a/pkg/providers/oauth_test.go b/pkg/providers/oauth_test.go index 78a6ca3..2856674 100644 --- a/pkg/providers/oauth_test.go +++ b/pkg/providers/oauth_test.go @@ -197,6 +197,44 @@ func TestHTTPProviderOAuthSwitchesAccountOnQuota(t *testing.T) { } } +func TestHTTPProviderOpenAICompatStreamMergesLateUsage(t *testing.T) { + t.Parallel() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/v1/chat/completions" { + http.NotFound(w, r) + return + } + w.Header().Set("Content-Type", "text/event-stream") + _, _ = w.Write([]byte("data: {\"choices\":[{\"index\":0,\"message\":{\"content\":\"hello\"},\"finish_reason\":\"stop\"}]}\n\n")) + _, _ = w.Write([]byte("data: {\"usage\":{\"prompt_tokens\":1,\"completion_tokens\":2,\"total_tokens\":3}}\n\n")) + })) + defer server.Close() + + 
provider := NewHTTPProvider("openai", "token", server.URL+"/v1", "gpt-test", false, "api_key", 5*time.Second, nil) + req, err := http.NewRequestWithContext(context.Background(), http.MethodPost, server.URL+"/v1/chat/completions", nil) + if err != nil { + t.Fatalf("new request failed: %v", err) + } + body, status, _, _, err := provider.doStreamAttempt(req, authAttempt{kind: "api_key", token: "token"}, nil) + if err != nil { + t.Fatalf("stream attempt failed: %v", err) + } + if status != http.StatusOK { + t.Fatalf("unexpected status: %d", status) + } + resp, err := parseOpenAICompatResponse(body) + if err != nil { + t.Fatalf("parse response failed: %v", err) + } + if resp.Content != "hello" { + t.Fatalf("unexpected response content: %q", resp.Content) + } + if resp.Usage == nil || resp.Usage.PromptTokens != 1 || resp.Usage.CompletionTokens != 2 || resp.Usage.TotalTokens != 3 { + t.Fatalf("unexpected usage: %#v", resp.Usage) + } +} + func TestOAuthManagerPreRefreshesExpiringSession(t *testing.T) { t.Parallel() From b5fd246cef8fd6c1470ffd1b1bd9d15762fc2244 Mon Sep 17 00:00:00 2001 From: lpf Date: Wed, 8 Apr 2026 15:45:52 +0800 Subject: [PATCH 2/5] Fix channel draft persistence --- pkg/api/server.go | 46 ++++++++++++++++++-- pkg/api/server_test.go | 97 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 140 insertions(+), 3 deletions(-) diff --git a/pkg/api/server.go b/pkg/api/server.go index fba097d..b17f9a9 100644 --- a/pkg/api/server.go +++ b/pkg/api/server.go @@ -241,6 +241,32 @@ func (s *Server) effectiveWeixinRuntime(persisted cfgpkg.WeixinConfig) (cfgpkg.W return cloneWeixinConfig(persisted), s.weixinChannel, false } +func (s *Server) ensureWeixinRuntimeForLogin(persisted cfgpkg.WeixinConfig) (cfgpkg.WeixinConfig, *channels.WeixinChannel, bool, error) { + s.draftMu.Lock() + defer s.draftMu.Unlock() + if s.channelDrafts.Weixin != nil { + s.syncWeixinDraftLocked() + effective := cloneWeixinConfig(*s.channelDrafts.Weixin) + return effective, 
s.channelDrafts.weixinRuntime, true, nil + } + if s.weixinChannel != nil { + return cloneWeixinConfig(persisted), s.weixinChannel, false, nil + } + + bootstrap := cloneWeixinConfig(persisted) + bootstrap.Enabled = true + if strings.TrimSpace(bootstrap.BaseURL) == "" { + bootstrap.BaseURL = "https://ilinkai.weixin.qq.com" + } + s.channelDrafts.Weixin = &bootstrap + if err := s.replaceWeixinDraftRuntimeLocked(&bootstrap); err != nil { + return cloneWeixinConfig(persisted), nil, true, err + } + s.syncWeixinDraftLocked() + effective := cloneWeixinConfig(*s.channelDrafts.Weixin) + return effective, s.channelDrafts.weixinRuntime, true, nil +} + func (s *Server) currentChannelDraftPayload(cfg *cfgpkg.Config, channel string) map[string]interface{} { channel = strings.ToLower(strings.TrimSpace(channel)) payload := map[string]interface{}{ @@ -765,7 +791,12 @@ func (s *Server) saveWebUIConfig(r *http.Request) error { if err := json.NewDecoder(r.Body).Decode(cfg); err != nil { return fmt.Errorf("decode config: %w", err) } - return s.persistWebUIConfig(cfg) + s.applyChannelDrafts(cfg) + if err := s.persistWebUIConfig(cfg); err != nil { + return err + } + s.clearChannelDrafts() + return nil case "normalized": cfg, err := cfgpkg.LoadConfig(s.configPath) if err != nil { @@ -776,7 +807,12 @@ func (s *Server) saveWebUIConfig(r *http.Request) error { return fmt.Errorf("decode normalized config: %w", err) } cfg.ApplyNormalizedView(view) - return s.persistWebUIConfig(cfg) + s.applyChannelDrafts(cfg) + if err := s.persistWebUIConfig(cfg); err != nil { + return err + } + s.clearChannelDrafts() + return nil default: return fmt.Errorf("unsupported config mode: %s", mode) } @@ -1597,7 +1633,11 @@ func (s *Server) handleWebUIWeixinLoginStart(w http.ResponseWriter, r *http.Requ http.Error(w, err.Error(), http.StatusInternalServerError) return } - _, ch, _ := s.effectiveWeixinRuntime(cfg.Channels.Weixin) + _, ch, _, err := s.ensureWeixinRuntimeForLogin(cfg.Channels.Weixin) + if err != nil { + 
http.Error(w, err.Error(), http.StatusBadGateway) + return + } if ch == nil { http.Error(w, "weixin channel unavailable", http.StatusServiceUnavailable) return diff --git a/pkg/api/server_test.go b/pkg/api/server_test.go index ec7642e..3ac1f9e 100644 --- a/pkg/api/server_test.go +++ b/pkg/api/server_test.go @@ -234,6 +234,103 @@ func TestHandleWebUIWeixinStatusReflectsDraftRuntime(t *testing.T) { } } +func TestHandleWebUIConfigPostPersistsChannelDrafts(t *testing.T) { + t.Parallel() + + tmp := t.TempDir() + cfgPath := filepath.Join(tmp, "config.json") + cfg := cfgpkg.DefaultConfig() + cfg.Channels.Weixin.Enabled = false + if err := cfgpkg.SaveConfig(cfgPath, cfg); err != nil { + t.Fatalf("save config: %v", err) + } + + srv := NewServer("127.0.0.1", 0, "") + srv.SetConfigPath(cfgPath) + srv.SetMessageBus(bus.NewMessageBus()) + srv.SetConfigAfterHook(func(forceRuntimeReload bool) error { return nil }) + + draftReq := httptest.NewRequest(http.MethodPost, "/api/channels/draft", strings.NewReader(`{"channel":"weixin","config":{"enabled":true,"base_url":"https://ilinkai.weixin.qq.com"}}`)) + draftReq.Header.Set("Content-Type", "application/json") + draftRec := httptest.NewRecorder() + srv.handleWebUIChannelDraft(draftRec, draftReq) + if draftRec.Code != http.StatusOK { + t.Fatalf("expected 200 from weixin draft save, got %d: %s", draftRec.Code, draftRec.Body.String()) + } + + saveReq := httptest.NewRequest(http.MethodPost, "/api/config", 
strings.NewReader(`{"gateway":{"host":"127.0.0.1","port":7788,"token":"abc"},"logging":{"enabled":false,"persist":false,"level":"debug","file":"logs/app.log","format":"text"},"models":{"providers":{"openai":{"api_base":"https://api.openai.com/v1","auth":"bearer","api_key":"secret","models":["gpt-5"],"timeout_sec":120}}},"tools":{"shell":{"enabled":true},"mcp":{"enabled":false}},"agents":{"defaults":{"model":{"primary":"openai/gpt-5"},"max_tool_iterations":10,"execution":{"run_state_ttl_seconds":3600,"run_state_max":128,"tool_parallel_safe_names":[],"tool_max_parallel_calls":4}},"router":{"enabled":false,"policy":{"intent_max_input_chars":2000,"max_rounds_without_user":3}},"subagents":{}},"channels":{"telegram":{"enabled":true,"token":"bot-token"}},"cron":{"enabled":false},"sentinel":{"enabled":false}}`)) + saveReq.Header.Set("Content-Type", "application/json") + saveRec := httptest.NewRecorder() + srv.handleWebUIConfig(saveRec, saveReq) + if saveRec.Code != http.StatusOK { + t.Fatalf("expected 200 from config save, got %d: %s", saveRec.Code, saveRec.Body.String()) + } + + updated, err := cfgpkg.LoadConfig(cfgPath) + if err != nil { + t.Fatalf("reload config: %v", err) + } + if !updated.Channels.Weixin.Enabled { + t.Fatalf("expected weixin enabled after config save with drafts") + } + + srv.draftMu.Lock() + defer srv.draftMu.Unlock() + if srv.channelDrafts.Weixin != nil { + t.Fatalf("expected weixin draft cleared after config save") + } +} + +func TestHandleWebUIWeixinLoginStartBootstrapsDraftRuntime(t *testing.T) { + t.Parallel() + + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet || r.URL.Path != "/ilink/bot/get_bot_qrcode" { + http.NotFound(w, r) + return + } + _ = json.NewEncoder(w).Encode(map[string]string{ + "qrcode": "wx-qr", + "qrcode_img_content": "wx-qr-img", + }) + })) + defer upstream.Close() + + tmp := t.TempDir() + cfgPath := filepath.Join(tmp, "config.json") + cfg := 
cfgpkg.DefaultConfig() + cfg.Channels.Weixin.Enabled = false + cfg.Channels.Weixin.BaseURL = upstream.URL + if err := cfgpkg.SaveConfig(cfgPath, cfg); err != nil { + t.Fatalf("save config: %v", err) + } + + srv := NewServer("127.0.0.1", 0, "") + srv.SetConfigPath(cfgPath) + srv.SetMessageBus(bus.NewMessageBus()) + + req := httptest.NewRequest(http.MethodPost, "/api/weixin/login/start", nil) + rec := httptest.NewRecorder() + srv.handleWebUIWeixinLoginStart(rec, req) + if rec.Code != http.StatusOK { + t.Fatalf("expected 200 from login start, got %d: %s", rec.Code, rec.Body.String()) + } + + var payload map[string]interface{} + if err := json.Unmarshal(rec.Body.Bytes(), &payload); err != nil { + t.Fatalf("decode response: %v", err) + } + if payload["draft_dirty"] != true { + t.Fatalf("expected draft_dirty=true, got %#v", payload["draft_dirty"]) + } + if payload["config_enabled"] != false { + t.Fatalf("expected config_enabled=false, got %#v", payload["config_enabled"]) + } + if payload["runtime_enabled"] != true { + t.Fatalf("expected runtime_enabled=true, got %#v", payload["runtime_enabled"]) + } +} + func TestWithCORSEchoesPreflightHeaders(t *testing.T) { t.Parallel() From 36890c7ce0066681971eb28f1a26ace5992d4a53 Mon Sep 17 00:00:00 2001 From: lpf Date: Thu, 9 Apr 2026 10:10:03 +0800 Subject: [PATCH 3/5] fix(weixin): dedupe login-status refresh and remove allow_from from example --- config.example.json | 3 +- pkg/channels/weixin.go | 44 +++++++++++- pkg/channels/weixin_test.go | 131 ++++++++++++++++++++++++++++++++++++ 3 files changed, 174 insertions(+), 4 deletions(-) diff --git a/config.example.json b/config.example.json index 4eda1b1..9ec4bf8 100644 --- a/config.example.json +++ b/config.example.json @@ -149,8 +149,7 @@ "enabled": false, "base_url": "https://ilinkai.weixin.qq.com", "default_bot_id": "", - "accounts": [], - "allow_from": [] + "accounts": [] }, "telegram": { "enabled": false, diff --git a/pkg/channels/weixin.go b/pkg/channels/weixin.go index 
0fba49b..5db3628 100644 --- a/pkg/channels/weixin.go +++ b/pkg/channels/weixin.go @@ -39,6 +39,7 @@ const ( weixinConfigCacheTTL = 24 * time.Hour weixinConfigRetryInitial = 2 * time.Second weixinConfigRetryMax = time.Hour + weixinLoginStatusMinGap = 1200 * time.Millisecond ) type WeixinChannel struct { @@ -63,6 +64,9 @@ type WeixinChannel struct { typingCache map[string]weixinTypingCacheEntry pauseMu sync.Mutex pauseUntil time.Time + loginStatusMu sync.Mutex + loginStatusAt time.Time + loginStatusIn chan struct{} } type weixinTypingCacheEntry struct { @@ -308,15 +312,51 @@ func (c *WeixinChannel) StartLogin(ctx context.Context) (*WeixinPendingLogin, er } func (c *WeixinChannel) RefreshLoginStatuses(ctx context.Context) ([]*WeixinPendingLogin, error) { + for { + c.loginStatusMu.Lock() + now := time.Now() + if !c.loginStatusAt.IsZero() && now.Sub(c.loginStatusAt) < weixinLoginStatusMinGap { + c.loginStatusMu.Unlock() + return c.PendingLogins(), nil + } + if wait := c.loginStatusIn; wait != nil { + c.loginStatusMu.Unlock() + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-wait: + } + continue + } + wait := make(chan struct{}) + c.loginStatusIn = wait + c.loginStatusMu.Unlock() + + err := c.refreshAllLoginStatuses(ctx) + + c.loginStatusMu.Lock() + c.loginStatusAt = time.Now() + close(c.loginStatusIn) + c.loginStatusIn = nil + c.loginStatusMu.Unlock() + + if err != nil { + return nil, err + } + return c.PendingLogins(), nil + } +} + +func (c *WeixinChannel) refreshAllLoginStatuses(ctx context.Context) error { c.mu.RLock() loginIDs := append([]string(nil), c.loginOrder...) 
c.mu.RUnlock() for _, loginID := range loginIDs { if err := c.refreshLoginStatus(ctx, loginID); err != nil { - return nil, err + return err } } - return c.PendingLogins(), nil + return nil } func (c *WeixinChannel) PendingLogins() []*WeixinPendingLogin { diff --git a/pkg/channels/weixin_test.go b/pkg/channels/weixin_test.go index c57baff..71e7cb2 100644 --- a/pkg/channels/weixin_test.go +++ b/pkg/channels/weixin_test.go @@ -394,6 +394,137 @@ func TestWeixinGetTypingTicketCachesAndFallsBack(t *testing.T) { } } +func TestWeixinRefreshLoginStatusesDeduplicatesConcurrentCalls(t *testing.T) { + mb := bus.NewMessageBus() + ch, err := NewWeixinChannel(config.WeixinConfig{ + BaseURL: "https://ilinkai.weixin.qq.com", + }, mb) + if err != nil { + t.Fatalf("new weixin channel: %v", err) + } + ch.pendingLogins["login-1"] = &WeixinPendingLogin{ + LoginID: "login-1", + QRCode: "code-1", + Status: "wait", + UpdatedAt: time.Now().UTC().Format(time.RFC3339), + } + ch.loginOrder = []string{"login-1"} + + var calls int + var callsMu sync.Mutex + started := make(chan struct{}, 1) + release := make(chan struct{}) + ch.httpClient = &http.Client{Transport: weixinRoundTripFunc(func(req *http.Request) (*http.Response, error) { + if req.URL.Path == "/ilink/bot/get_qrcode_status" { + callsMu.Lock() + calls++ + callsMu.Unlock() + select { + case started <- struct{}{}: + default: + } + <-release + body := `{"status":"wait"}` + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(body)), + Header: make(http.Header), + }, nil + } + return &http.Response{ + StatusCode: http.StatusNotFound, + Body: io.NopCloser(strings.NewReader("not found")), + Header: make(http.Header), + }, nil + })} + + errCh := make(chan error, 2) + go func() { + _, callErr := ch.RefreshLoginStatuses(context.Background()) + errCh <- callErr + }() + select { + case <-started: + case <-time.After(time.Second): + t.Fatalf("timed out waiting for first refresh request") + } + + go func() { + 
_, callErr := ch.RefreshLoginStatuses(context.Background()) + errCh <- callErr + }() + time.Sleep(50 * time.Millisecond) + + callsMu.Lock() + gotCalls := calls + callsMu.Unlock() + if gotCalls != 1 { + t.Fatalf("expected exactly 1 upstream status call while refresh in-flight, got %d", gotCalls) + } + + close(release) + for i := 0; i < 2; i++ { + if callErr := <-errCh; callErr != nil { + t.Fatalf("refresh call %d returned error: %v", i+1, callErr) + } + } +} + +func TestWeixinRefreshLoginStatusesHonorsMinGap(t *testing.T) { + mb := bus.NewMessageBus() + ch, err := NewWeixinChannel(config.WeixinConfig{ + BaseURL: "https://ilinkai.weixin.qq.com", + }, mb) + if err != nil { + t.Fatalf("new weixin channel: %v", err) + } + ch.pendingLogins["login-1"] = &WeixinPendingLogin{ + LoginID: "login-1", + QRCode: "code-1", + Status: "wait", + UpdatedAt: time.Now().UTC().Format(time.RFC3339), + } + ch.loginOrder = []string{"login-1"} + + var calls int + ch.httpClient = &http.Client{Transport: weixinRoundTripFunc(func(req *http.Request) (*http.Response, error) { + if req.URL.Path == "/ilink/bot/get_qrcode_status" { + calls++ + body := `{"status":"wait"}` + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(body)), + Header: make(http.Header), + }, nil + } + return &http.Response{ + StatusCode: http.StatusNotFound, + Body: io.NopCloser(strings.NewReader("not found")), + Header: make(http.Header), + }, nil + })} + + if _, err := ch.RefreshLoginStatuses(context.Background()); err != nil { + t.Fatalf("first refresh: %v", err) + } + if _, err := ch.RefreshLoginStatuses(context.Background()); err != nil { + t.Fatalf("second refresh: %v", err) + } + if calls != 1 { + t.Fatalf("expected second refresh within min gap to reuse cached result, calls=%d", calls) + } + + ch.loginStatusMu.Lock() + ch.loginStatusAt = time.Now().Add(-weixinLoginStatusMinGap - time.Millisecond) + ch.loginStatusMu.Unlock() + if _, err := 
ch.RefreshLoginStatuses(context.Background()); err != nil { + t.Fatalf("third refresh: %v", err) + } + if calls != 2 { + t.Fatalf("expected refresh after min gap to hit upstream again, calls=%d", calls) + } +} + func TestPollDelayForAttempt(t *testing.T) { if got := pollDelayForAttempt(1); got != weixinRetryDelay { t.Fatalf("attempt 1 delay = %s", got) From fac235db805c4b788a7cc40f70c67c15ccda2083 Mon Sep 17 00:00:00 2001 From: lpf Date: Mon, 13 Apr 2026 13:41:01 +0800 Subject: [PATCH 4/5] feat: harden jsonl runtime reliability --- cmd/cmd_gateway.go | 36 +- cmd/cmd_onboard.go | 5 +- cmd/main.go | 4 - embedded_assets.go | 8 + pkg/agent/context.go | 6 +- pkg/agent/context_spec_test.go | 15 + pkg/agent/loop.go | 833 ++++++++++--- pkg/agent/loop_codex_options_test.go | 5 +- pkg/agent/reliability_test.go | 348 ++++++ pkg/api/server.go | 112 +- pkg/api/server_test.go | 67 + pkg/config/config.go | 37 +- pkg/config/validate.go | 12 + pkg/jsonlog/store.go | 96 ++ pkg/providers/execution.go | 80 +- pkg/providers/http_provider.go | 38 +- pkg/providers/iflow_provider.go | 2 +- pkg/providers/oauth_test.go | 2 +- pkg/providers/openai_compat_provider.go | 2 +- pkg/providers/types.go | 13 + pkg/session/manager.go | 1526 ++++++++++++++++++++--- pkg/session/manager_test.go | 198 +++ pkg/tools/runtime_types.go | 21 +- pkg/tools/session_search.go | 103 ++ pkg/tools/spawn.go | 28 +- pkg/tools/subagent.go | 295 +++-- pkg/tools/subagent_budget_test.go | 123 ++ pkg/tools/subagent_mailbox.go | 412 ++++-- pkg/tools/subagent_mailbox_test.go | 81 ++ pkg/tools/subagent_profile.go | 112 +- pkg/tools/subagent_runtime_context.go | 51 + pkg/tools/subagent_store.go | 382 ++++-- pkg/tools/subagent_store_test.go | 70 ++ pkg/tools/tool_allowlist_groups.go | 4 +- 34 files changed, 4370 insertions(+), 757 deletions(-) create mode 100644 embedded_assets.go create mode 100644 pkg/agent/reliability_test.go create mode 100644 pkg/jsonlog/store.go create mode 100644 pkg/tools/session_search.go create mode 
100644 pkg/tools/subagent_budget_test.go create mode 100644 pkg/tools/subagent_mailbox_test.go create mode 100644 pkg/tools/subagent_runtime_context.go create mode 100644 pkg/tools/subagent_store_test.go diff --git a/cmd/cmd_gateway.go b/cmd/cmd_gateway.go index d2adee4..74c495a 100644 --- a/cmd/cmd_gateway.go +++ b/cmd/cmd_gateway.go @@ -151,8 +151,11 @@ func gatewayCmd() { } return loop.ProcessDirect(cctx, content, sessionKey) }) - registryServer.SetChatHistoryHandler(func(sessionKey string) []map[string]interface{} { - h := loop.GetSessionHistory(sessionKey) + registryServer.SetChatHistoryHandler(func(query api.ChatHistoryQuery) []map[string]interface{} { + h := loop.GetSessionHistory(query.Session) + if query.Around > 0 || query.Before > 0 || query.After > 0 || query.Limit > 0 { + h = loop.GetSessionHistoryWindow(query.Session, query.Around, query.Before, query.After, query.Limit) + } out := make([]map[string]interface{}, 0, len(h)) for _, m := range h { entry := map[string]interface{}{"role": m.Role, "content": m.Content} @@ -166,6 +169,35 @@ func gatewayCmd() { } return out }) + registryServer.SetSessionSearchHandler(func(query api.SessionSearchQuery) []map[string]interface{} { + excludeKey := "" + if query.ExcludeCurrent { + excludeKey = strings.TrimSpace(query.Session) + } + results := loop.SearchSessions(query.Query, query.Kinds, excludeKey, query.Limit) + out := make([]map[string]interface{}, 0, len(results)) + for _, item := range results { + entry := map[string]interface{}{ + "key": item.Key, + "kind": item.Kind, + "updated_at": item.UpdatedAt.UnixMilli(), + "summary": item.Summary, + "score": item.Score, + } + snippets := make([]map[string]interface{}, 0, len(item.Snippets)) + for _, snippet := range item.Snippets { + snippets = append(snippets, map[string]interface{}{ + "seq": snippet.Seq, + "role": snippet.Role, + "segment": snippet.Segment, + "content": snippet.Content, + }) + } + entry["snippets"] = snippets + out = append(out, entry) + } + return 
out + }) registryServer.SetToolsCatalogHandler(func() interface{} { return loop.GetToolCatalog() }) diff --git a/cmd/cmd_onboard.go b/cmd/cmd_onboard.go index 6e15d05..2050587 100644 --- a/cmd/cmd_onboard.go +++ b/cmd/cmd_onboard.go @@ -6,6 +6,7 @@ import ( "os" "path/filepath" + clawgoassets "github.com/YspCoder/clawgo" "github.com/YspCoder/clawgo/pkg/config" ) @@ -42,7 +43,7 @@ func copyEmbeddedToTarget(targetDir string, overwrite func(relPath string) bool) return fmt.Errorf("failed to create target directory: %w", err) } - return fs.WalkDir(embeddedFiles, "workspace", func(path string, d fs.DirEntry, err error) error { + return fs.WalkDir(clawgoassets.WorkspaceTemplates, "workspace", func(path string, d fs.DirEntry, err error) error { if err != nil { return err } @@ -50,7 +51,7 @@ func copyEmbeddedToTarget(targetDir string, overwrite func(relPath string) bool) return nil } - data, err := embeddedFiles.ReadFile(path) + data, err := clawgoassets.WorkspaceTemplates.ReadFile(path) if err != nil { return fmt.Errorf("failed to read embedded file %s: %w", path, err) } diff --git a/cmd/main.go b/cmd/main.go index 6e4e429..69c85bc 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -7,7 +7,6 @@ package main import ( - "embed" "errors" "fmt" "os" @@ -16,9 +15,6 @@ import ( "github.com/YspCoder/clawgo/pkg/logger" ) -//go:embed workspace -var embeddedFiles embed.FS - var version = "0.0.2" var buildTime = "unknown" diff --git a/embedded_assets.go b/embedded_assets.go new file mode 100644 index 0000000..a753373 --- /dev/null +++ b/embedded_assets.go @@ -0,0 +1,8 @@ +package clawgoassets + +import "embed" + +// WorkspaceTemplates exposes the bundled workspace scaffold used by onboard. 
+// +//go:embed workspace +var WorkspaceTemplates embed.FS diff --git a/pkg/agent/context.go b/pkg/agent/context.go index 6652d93..589b16a 100644 --- a/pkg/agent/context.go +++ b/pkg/agent/context.go @@ -73,6 +73,10 @@ Your workspace is at: %s - Active project spec docs (when present): %s/{spec.md,tasks.md,checklist.md} - Keep spec.md as project scope / decisions, tasks.md as execution plan, checklist.md as final verification gate +## Session Recall +- When the user refers to previous conversations, earlier project work, or past decisions, prefer session_search before guessing from memory +- Use memory_search for durable notes in MEMORY files, and session_search for historical chat transcripts + %s`, now, runtime, workspacePath, workspacePath, workspacePath, workspacePath, cb.projectRootPath(), toolsSection) } @@ -297,7 +301,7 @@ func (cb *ContextBuilder) BuildMessagesWithMemoryNamespace(history []providers.M }) if summary != "" { - systemPrompt += "\n\n## Summary of Previous Conversation\n\n" + summary + systemPrompt += "\n\n## Summary of Previous Conversation\nThis is a handoff summary from earlier compacted context. 
Treat it as background reference, not a new user instruction.\n\n" + summary } messages = append(messages, providers.Message{ diff --git a/pkg/agent/context_spec_test.go b/pkg/agent/context_spec_test.go index ba70899..70c9516 100644 --- a/pkg/agent/context_spec_test.go +++ b/pkg/agent/context_spec_test.go @@ -89,3 +89,18 @@ func TestShouldUseSpecCodingRequiresExplicitAndNonTrivialCodingIntent(t *testing } } } + +func TestBuildMessagesIncludesSessionRecallGuidanceAndSummaryHandoff(t *testing.T) { + cb := NewContextBuilder(t.TempDir(), nil) + msgs := cb.BuildMessagesWithMemoryNamespace(nil, "Key Facts\n- prior work", "继续昨天那个改动", nil, "cli", "direct", "", "main") + if len(msgs) == 0 { + t.Fatalf("expected system message") + } + content := msgs[0].Content + if !strings.Contains(content, "session_search") { + t.Fatalf("expected session_search guidance in system prompt, got:\n%s", content) + } + if !strings.Contains(content, "handoff summary") { + t.Fatalf("expected handoff summary note, got:\n%s", content) + } +} diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go index 7c12080..a619c5f 100644 --- a/pkg/agent/loop.go +++ b/pkg/agent/loop.go @@ -26,6 +26,7 @@ import ( "github.com/YspCoder/clawgo/pkg/bus" "github.com/YspCoder/clawgo/pkg/config" "github.com/YspCoder/clawgo/pkg/cron" + "github.com/YspCoder/clawgo/pkg/lifecycle" "github.com/YspCoder/clawgo/pkg/logger" "github.com/YspCoder/clawgo/pkg/providers" "github.com/YspCoder/clawgo/pkg/runtimecfg" @@ -35,45 +36,57 @@ import ( ) type AgentLoop struct { - bus *bus.MessageBus - cfg *config.Config - provider providers.LLMProvider - workspace string - model string - maxTokens int - temperature float64 - maxIterations int - sessions *session.SessionManager - contextBuilder *ContextBuilder - tools *tools.ToolRegistry - compactionEnabled bool - compactionTrigger int - compactionKeepRecent int - heartbeatAckMaxChars int - heartbeatAckToken string - audit *triggerAudit - running bool - sessionScheduler *SessionScheduler - 
providerChain []providerCandidate - providerNames []string - providerPool map[string]providers.LLMProvider - providerResponses map[string]config.ProviderResponsesConfig - providerMaxTokens map[string]int - providerTemperatures map[string]float64 - telegramStreaming bool - providerMu sync.RWMutex - sessionProvider map[string]string - streamMu sync.Mutex - sessionStreamed map[string]bool - subagentManager *tools.SubagentManager - subagentRouter *tools.SubagentRouter - configPath string - subagentDigestMu sync.Mutex - subagentDigestDelay time.Duration - subagentDigests map[string]*subagentDigestState - runMu sync.Mutex - runCancel context.CancelFunc - runWG sync.WaitGroup + bus *bus.MessageBus + cfg *config.Config + provider providers.LLMProvider + workspace string + model string + maxTokens int + temperature float64 + maxIterations int + sessions *session.SessionManager + contextBuilder *ContextBuilder + tools *tools.ToolRegistry + compactionEnabled bool + compactionTrigger int + compactionKeepRecent int + compactionProtectLastN int + compactionTargetRatio float64 + compactionPressureThreshold float64 + compactionMaxSummaryChars int + compactionMaxTranscriptChars int + heartbeatAckMaxChars int + heartbeatAckToken string + audit *triggerAudit + running bool + sessionScheduler *SessionScheduler + providerChain []providerCandidate + providerNames []string + providerPool map[string]providers.LLMProvider + providerResponses map[string]config.ProviderResponsesConfig + providerMaxTokens map[string]int + providerTemperatures map[string]float64 + telegramStreaming bool + providerMu sync.RWMutex + sessionProvider map[string]string + streamMu sync.Mutex + sessionStreamed map[string]bool + subagentManager *tools.SubagentManager + subagentRouter *tools.SubagentRouter + configPath string + subagentDigestMu sync.Mutex + subagentDigestDelay time.Duration + subagentDigests map[string]*subagentDigestState + compactionRunner *lifecycle.LoopRunner + compactionMu sync.Mutex + 
compactionSignal chan struct{} + compactionQueue []string + compactionQueued map[string]struct{} + compactionInflight map[string]struct{} + compactionDirty map[string]struct{} + runMu sync.Mutex + runCancel context.CancelFunc + runWG sync.WaitGroup } type providerCandidate struct { @@ -198,6 +211,7 @@ func NewAgentLoop(cfg *config.Config, msgBus *bus.MessageBus, provider providers return h }, )) + toolsRegistry.Register(tools.NewSessionSearchTool(sessionsManager)) // Register edit file tool editFileTool := tools.NewEditFileTool(workspace) @@ -221,35 +235,45 @@ func NewAgentLoop(cfg *config.Config, msgBus *bus.MessageBus, provider providers toolsRegistry.Register(tools.NewSystemInfoTool()) loop := &AgentLoop{ - bus: msgBus, - cfg: cfg, - provider: provider, - workspace: workspace, - model: provider.GetDefaultModel(), - maxTokens: cfg.Agents.Defaults.MaxTokens, - temperature: cfg.Agents.Defaults.Temperature, - maxIterations: cfg.Agents.Defaults.MaxToolIterations, - sessions: sessionsManager, - contextBuilder: NewContextBuilder(workspace, func() []string { return toolsRegistry.GetSummaries() }), - tools: toolsRegistry, - compactionEnabled: cfg.Agents.Defaults.ContextCompaction.Enabled, - compactionTrigger: cfg.Agents.Defaults.ContextCompaction.TriggerMessages, - compactionKeepRecent: cfg.Agents.Defaults.ContextCompaction.KeepRecentMessages, - heartbeatAckMaxChars: cfg.Agents.Defaults.Heartbeat.AckMaxChars, - heartbeatAckToken: loadHeartbeatAckToken(workspace), - audit: newTriggerAudit(workspace), - running: false, - sessionScheduler: NewSessionScheduler(0), - sessionProvider: map[string]string{}, - sessionStreamed: map[string]bool{}, - providerResponses: map[string]config.ProviderResponsesConfig{}, - providerMaxTokens: map[string]int{}, - providerTemperatures: map[string]float64{}, - telegramStreaming: cfg.Channels.Telegram.Streaming, - subagentManager: subagentManager, - subagentRouter: subagentRouter, - subagentDigestDelay: 5 * time.Second, - subagentDigests: 
map[string]*subagentDigestState{}, + bus: msgBus, + cfg: cfg, + provider: provider, + workspace: workspace, + model: provider.GetDefaultModel(), + maxTokens: cfg.Agents.Defaults.MaxTokens, + temperature: cfg.Agents.Defaults.Temperature, + maxIterations: cfg.Agents.Defaults.MaxToolIterations, + sessions: sessionsManager, + contextBuilder: NewContextBuilder(workspace, func() []string { return toolsRegistry.GetSummaries() }), + tools: toolsRegistry, + compactionEnabled: cfg.Agents.Defaults.ContextCompaction.Enabled, + compactionTrigger: cfg.Agents.Defaults.ContextCompaction.TriggerMessages, + compactionKeepRecent: cfg.Agents.Defaults.ContextCompaction.KeepRecentMessages, + compactionProtectLastN: normalizePositiveInt(cfg.Agents.Defaults.ContextCompaction.ProtectLastN, 12), + compactionTargetRatio: normalizeCompactionRatio(cfg.Agents.Defaults.ContextCompaction.TargetRatio, 0.35), + compactionPressureThreshold: normalizeCompactionRatio(cfg.Agents.Defaults.ContextCompaction.PressureThreshold, 0.8), + compactionMaxSummaryChars: normalizePositiveInt(cfg.Agents.Defaults.ContextCompaction.MaxSummaryChars, 6000), + compactionMaxTranscriptChars: normalizePositiveInt(cfg.Agents.Defaults.ContextCompaction.MaxTranscriptChars, 20000), + heartbeatAckMaxChars: cfg.Agents.Defaults.Heartbeat.AckMaxChars, + heartbeatAckToken: loadHeartbeatAckToken(workspace), + audit: newTriggerAudit(workspace), + running: false, + sessionScheduler: NewSessionScheduler(0), + sessionProvider: map[string]string{}, + sessionStreamed: map[string]bool{}, + providerResponses: map[string]config.ProviderResponsesConfig{}, + providerMaxTokens: map[string]int{}, + providerTemperatures: map[string]float64{}, + telegramStreaming: cfg.Channels.Telegram.Streaming, + subagentManager: subagentManager, + subagentRouter: subagentRouter, + subagentDigestDelay: 5 * time.Second, + subagentDigests: map[string]*subagentDigestState{}, + compactionRunner: lifecycle.NewLoopRunner(), + compactionSignal: make(chan struct{}, 1), + 
compactionQueued: map[string]struct{}{}, + compactionInflight: map[string]struct{}{}, + compactionDirty: map[string]struct{}{}, } if _, primaryModel := config.ParseProviderModelRef(cfg.Agents.Defaults.Model.Primary); strings.TrimSpace(primaryModel) != "" { loop.model = strings.TrimSpace(primaryModel) @@ -335,6 +359,9 @@ func NewAgentLoop(cfg *config.Config, msgBus *bus.MessageBus, provider providers if run == nil { return "", fmt.Errorf("subagent run is nil") } + if run.MaxToolIterations <= 0 { + run.MaxToolIterations = loop.maxIterations + } sessionKey := strings.TrimSpace(run.SessionKey) if sessionKey == "" { sessionKey = fmt.Sprintf("subagent:%s", strings.TrimSpace(run.ID)) @@ -457,6 +484,9 @@ func (al *AgentLoop) Stop() { } al.running = false al.runWG.Wait() + if al.compactionRunner != nil { + al.compactionRunner.Stop() + } } func (al *AgentLoop) buildSessionShards(ctx context.Context) []chan bus.InboundMessage { @@ -476,11 +506,12 @@ func (al *AgentLoop) buildSessionShards(ctx context.Context) []chan bus.InboundM return shards } -func (al *AgentLoop) tryFallbackProviders(ctx context.Context, msg bus.InboundMessage, messages []providers.Message, toolDefs []providers.ToolDefinition, primaryErr error) (*providers.LLMResponse, string, error) { +func (al *AgentLoop) tryFallbackProviders(ctx context.Context, msg bus.InboundMessage, messages []providers.Message, toolDefs []providers.ToolDefinition, primaryErr error) (*providers.LLMResponse, string, int, error) { if len(al.providerChain) <= 1 { - return nil, "", primaryErr + return nil, "", 0, primaryErr } lastErr := primaryErr + attempts := 0 candidateNames := make([]string, 0, len(al.providerChain)-1) for _, candidate := range al.providerChain[1:] { candidateNames = append(candidateNames, candidate.name) @@ -507,16 +538,17 @@ func (al *AgentLoop) tryFallbackProviders(ctx context.Context, msg bus.InboundMe lastErr = err continue } + attempts++ fallbackOptions := al.buildResponsesOptionsForProvider(msg.SessionKey, 
candidate.name, int64(al.maxTokensForProvider(candidate.name)), al.temperatureForProvider(candidate.name)) resp, err := p.Chat(ctx, messages, toolDefs, candidateModel, fallbackOptions) if err == nil { al.setSessionProvider(msg.SessionKey, candidate.name) logger.WarnCF("agent", logger.C0150, map[string]interface{}{"provider": candidate.name, "model": candidateModel, "ref": candidate.ref}) - return resp, candidate.name, nil + return resp, candidate.name, attempts, nil } lastErr = err } - return nil, "", lastErr + return nil, "", attempts, lastErr } func (al *AgentLoop) ensureProviderCandidate(candidate providerCandidate) (providers.LLMProvider, string, error) { @@ -865,6 +897,9 @@ type llmTurnLoopResult struct { pendingPersist []providers.Message finalContent string iteration int + attemptCount int + restartCount int + failureCode string hasToolActivity bool } @@ -960,39 +995,150 @@ func (al *AgentLoop) executeResponseToolCalls(cfg llmTurnLoopConfig, iteration i return results } -func (al *AgentLoop) requestLLMResponse(cfg llmTurnLoopConfig, activeProvider providers.LLMProvider, activeModel string, messages []providers.Message, providerToolDefs []providers.ToolDefinition, options map[string]interface{}) (*providers.LLMResponse, error) { +func (al *AgentLoop) requestLLMResponse(cfg llmTurnLoopConfig, activeProvider providers.LLMProvider, activeModel string, messages []providers.Message, providerToolDefs []providers.ToolDefinition, options map[string]interface{}) (*providers.LLMResponse, int, error) { if cfg.enableStreaming { if sp, ok := activeProvider.(providers.StreamingLLMProvider); ok { - streamText := "" - lastPush := time.Now().Add(-time.Second) - return sp.ChatStream(cfg.ctx, messages, providerToolDefs, activeModel, options, func(delta string) { - if strings.TrimSpace(delta) == "" { - return - } - streamText += delta - if time.Since(lastPush) < 450*time.Millisecond { - return - } - if !shouldFlushTelegramStreamSnapshot(streamText) { - return - } - lastPush = 
time.Now() - replyID := "" - if cfg.triggerMsg.Metadata != nil { - replyID = cfg.triggerMsg.Metadata["message_id"] - } - al.bus.PublishOutbound(bus.OutboundMessage{ - Channel: cfg.toolChannel, - ChatID: cfg.toolChatID, - Content: streamText, - Action: "stream", - ReplyToID: replyID, - }) - al.markSessionStreamed(cfg.sessionKey) - }) + return al.requestStreamingLLMResponse(cfg, sp, activeProvider, activeModel, messages, providerToolDefs, options) } } - return activeProvider.Chat(cfg.ctx, messages, providerToolDefs, activeModel, options) + resp, err := activeProvider.Chat(cfg.ctx, messages, providerToolDefs, activeModel, options) + return resp, 1, err +} + +func (al *AgentLoop) requestStreamingLLMResponse(cfg llmTurnLoopConfig, streamer providers.StreamingLLMProvider, fallbackProvider providers.LLMProvider, activeModel string, messages []providers.Message, providerToolDefs []providers.ToolDefinition, options map[string]interface{}) (*providers.LLMResponse, int, error) { + streamText := "" + lastPush := time.Now().Add(-time.Second) + type streamResult struct { + resp *providers.LLMResponse + err error + } + streamCtx, cancel := context.WithCancel(cfg.ctx) + defer cancel() + resultCh := make(chan streamResult, 1) + var deltaMu sync.Mutex + firstDeltaSeen := false + lastDeltaAt := time.Now() + go func() { + resp, err := streamer.ChatStream(streamCtx, messages, providerToolDefs, activeModel, options, func(delta string) { + if strings.TrimSpace(delta) == "" { + return + } + deltaMu.Lock() + firstDeltaSeen = true + lastDeltaAt = time.Now() + deltaMu.Unlock() + streamText += delta + if time.Since(lastPush) < 450*time.Millisecond { + return + } + if !shouldFlushTelegramStreamSnapshot(streamText) { + return + } + lastPush = time.Now() + replyID := "" + if cfg.triggerMsg.Metadata != nil { + replyID = cfg.triggerMsg.Metadata["message_id"] + } + al.bus.PublishOutbound(bus.OutboundMessage{ + Channel: cfg.toolChannel, + ChatID: cfg.toolChatID, + Content: streamText, + Action: 
"stream", + ReplyToID: replyID, + }) + al.markSessionStreamed(cfg.sessionKey) + }) + resultCh <- streamResult{resp: resp, err: err} + }() + + firstDeltaTimeout, idleDeltaTimeout := streamMonitorTimeouts(cfg.ctx) + staleTriggered := false + ticker := time.NewTicker(250 * time.Millisecond) + defer ticker.Stop() + for { + select { + case result := <-resultCh: + if result.err == nil { + return result.resp, 1, nil + } + deltaMu.Lock() + streamStarted := firstDeltaSeen + deltaMu.Unlock() + if staleTriggered && streamStarted { + return nil, 1, providers.NewProviderExecutionError("stream_stale", result.err.Error(), "stream", true, "") + } + if !streamStarted { + resp, err := fallbackProvider.Chat(cfg.ctx, messages, providerToolDefs, activeModel, options) + return resp, 2, err + } + return nil, 1, result.err + case <-ticker.C: + deltaMu.Lock() + streamStarted := firstDeltaSeen + idleFor := time.Since(lastDeltaAt) + deltaMu.Unlock() + timeout := firstDeltaTimeout + if streamStarted { + timeout = idleDeltaTimeout + } + if timeout > 0 && idleFor > timeout { + staleTriggered = true + cancel() + } + case <-cfg.ctx.Done(): + cancel() + return nil, 1, cfg.ctx.Err() + } + } +} + +func streamMonitorTimeouts(ctx context.Context) (time.Duration, time.Duration) { + first := 20 * time.Second + idle := 45 * time.Second + if deadline, ok := ctx.Deadline(); ok { + remaining := time.Until(deadline) + if remaining > 0 { + firstMin := 5 * time.Second + idleMin := 15 * time.Second + if remaining < firstMin { + firstMin = maxDuration(100*time.Millisecond, remaining/4) + } + if remaining < idleMin { + idleMin = maxDuration(250*time.Millisecond, remaining/3) + } + first = clampDuration(remaining/4, firstMin, first) + idle = clampDuration(remaining/3, idleMin, idle) + } + } + return first, idle +} + +func clampDuration(value, min, max time.Duration) time.Duration { + if value < min { + return min + } + if max > 0 && value > max { + return max + } + return value +} + +func maxDuration(a, b 
time.Duration) time.Duration { + if a > b { + return a + } + return b +} + +func (al *AgentLoop) maxIterationsForContext(ctx context.Context) int { + limit := al.maxIterations + if limit < 1 { + limit = 1 + } + if budget, ok := tools.SubagentIterationBudget(ctx); ok && budget > 0 { + return budget + } + return limit } func (al *AgentLoop) runLLMTurnLoop(cfg llmTurnLoopConfig) (llmTurnLoopResult, error) { @@ -1000,7 +1146,7 @@ func (al *AgentLoop) runLLMTurnLoop(cfg llmTurnLoopConfig) (llmTurnLoopResult, e messages: append([]providers.Message(nil), cfg.messages...), pendingPersist: make([]providers.Message, 0, 16), } - maxAllowed := al.maxIterations + maxAllowed := al.maxIterationsForContext(cfg.ctx) if maxAllowed < 1 { maxAllowed = 1 } @@ -1027,14 +1173,19 @@ func (al *AgentLoop) runLLMTurnLoop(cfg llmTurnLoopConfig) (llmTurnLoopResult, e logLLMTurnRequest(result.iteration, al.maxIterations, providerName, activeModel, result.messages, providerToolDefs, maxTokens, temperature) options := al.buildResponsesOptions(cfg.sessionKey, int64(maxTokens), temperature) - response, err := al.requestLLMResponse(cfg, activeProvider, activeModel, result.messages, providerToolDefs, options) + response, attempts, err := al.requestLLMResponse(cfg, activeProvider, activeModel, result.messages, providerToolDefs, options) + result.attemptCount += attempts if err != nil { - if fb, _, ferr := al.tryFallbackProviders(cfg.ctx, cfg.triggerMsg, result.messages, providerToolDefs, err); ferr == nil && fb != nil { + result.failureCode = classifyLLMFailureCode(err) + if fb, _, fallbackAttempts, ferr := al.tryFallbackProviders(cfg.ctx, cfg.triggerMsg, result.messages, providerToolDefs, err); ferr == nil && fb != nil { response = fb + result.attemptCount += fallbackAttempts err = nil } else { + result.attemptCount += fallbackAttempts err = ferr + result.failureCode = classifyLLMFailureCode(err) } } if err != nil { @@ -1059,9 +1210,6 @@ func (al *AgentLoop) runLLMTurnLoop(cfg llmTurnLoopConfig) 
(llmTurnLoopResult, e result.messages = append(result.messages, assistantMsg) result.pendingPersist = append(result.pendingPersist, assistantMsg) result.hasToolActivity = true - if maxAllowed < result.iteration+al.maxIterations { - maxAllowed = result.iteration + al.maxIterations - } for _, toolResultMsg := range al.executeResponseToolCalls(cfg, result.iteration, response) { result.messages = append(result.messages, toolResultMsg) @@ -1069,7 +1217,8 @@ func (al *AgentLoop) runLLMTurnLoop(cfg llmTurnLoopConfig) (llmTurnLoopResult, e } } - return result, nil + result.failureCode = "retry_limit" + return result, fmt.Errorf("max tool iterations exceeded (%d)", maxAllowed) } func (al *AgentLoop) logInboundMessageStart(msg bus.InboundMessage) { @@ -1083,7 +1232,7 @@ func (al *AgentLoop) logInboundMessageStart(msg bus.InboundMessage) { } func (al *AgentLoop) prepareUserMessageContext(msg bus.InboundMessage, memoryNamespace string) ([]providers.Message, string) { - history := al.sessions.GetHistory(msg.SessionKey) + history := al.sessions.GetPromptHistory(msg.SessionKey) summary := al.sessions.GetSummary(msg.SessionKey) al.sessions.AddMessage(msg.SessionKey, "user", msg.Content) if explicitPref := ExtractLanguagePreference(msg.Content); explicitPref != "" { @@ -1113,12 +1262,11 @@ func (al *AgentLoop) finalizeUserMessage(sessionKey, responseLang string, pendin Content: finalContent, }) al.sessions.SetLastLanguage(sessionKey, responseLang) - al.compactSessionIfNeeded(sessionKey) - _ = al.sessions.Save(al.sessions.GetOrCreate(sessionKey)) + al.enqueueSessionCompaction(sessionKey) } func (al *AgentLoop) prepareSystemMessageContext(sessionKey string, msg bus.InboundMessage, originChannel, originChatID string) ([]providers.Message, string) { - history := al.sessions.GetHistory(sessionKey) + history := al.sessions.GetPromptHistory(sessionKey) summary := al.sessions.GetSummary(sessionKey) preferredLang, lastLang := al.sessions.GetLanguagePreferences(sessionKey) responseLang := 
DetectResponseLanguage(msg.Content, preferredLang, lastLang) @@ -1144,8 +1292,7 @@ func (al *AgentLoop) finalizeSystemMessage(sessionKey, responseLang string, msg Content: finalContent, }) al.sessions.SetLastLanguage(sessionKey, responseLang) - al.compactSessionIfNeeded(sessionKey) - _ = al.sessions.Save(al.sessions.GetOrCreate(sessionKey)) + al.enqueueSessionCompaction(sessionKey) } func (al *AgentLoop) startSpecTaskForMessage(msg bus.InboundMessage) specCodingTaskRef { @@ -1372,6 +1519,14 @@ func (al *AgentLoop) GetSessionHistory(sessionKey string) []providers.Message { return al.sessions.GetHistory(sessionKey) } +func (al *AgentLoop) GetSessionHistoryWindow(sessionKey string, around, before, after, limit int) []providers.Message { + return al.sessions.GetHistoryWindow(sessionKey, around, before, after, limit) +} + +func (al *AgentLoop) SearchSessions(query string, kinds []string, excludeKey string, limit int) []session.SessionSearchResult { + return al.sessions.Search(query, kinds, excludeKey, limit) +} + func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage) (string, error) { if msg.SessionKey == "" { msg.SessionKey = "main" @@ -1411,9 +1566,20 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage) logDirectResponse: true, }) if err != nil { + tools.RecordSubagentExecutionStats(ctx, tools.SubagentExecutionStats{ + Iterations: loopResult.iteration, + Attempts: loopResult.attemptCount, + Restarts: loopResult.restartCount, + FailureCode: classifyLLMFailureCode(err), + }) al.reopenSpecTaskOnError(specTaskRef, msg, err) return "", err } + tools.RecordSubagentExecutionStats(ctx, tools.SubagentExecutionStats{ + Iterations: loopResult.iteration, + Attempts: loopResult.attemptCount, + Restarts: loopResult.restartCount, + }) finalContent, userContent := al.finalizeUserTurnResponse(ctx, msg, responseLang, loopResult) // Log response preview (original content) @@ -1940,32 +2106,215 @@ func isRemoteReference(ref string) 
bool { strings.HasPrefix(trimmed, "data:") } -// GetStartupInfo returns information about loaded tools and skills for logging. -func (al *AgentLoop) compactSessionIfNeeded(sessionKey string) { - if !al.compactionEnabled { +func (al *AgentLoop) ensureCompactionRunnerStarted() { + if al == nil || !al.compactionEnabled || al.compactionRunner == nil { return } + al.compactionRunner.Start(func(stop <-chan struct{}) { + al.runCompactionLoop(stop) + }) +} + +func (al *AgentLoop) signalCompactionWorker() { + if al == nil || al.compactionSignal == nil { + return + } + select { + case al.compactionSignal <- struct{}{}: + default: + } +} + +func (al *AgentLoop) enqueueSessionCompaction(sessionKey string) { + if al == nil || !al.compactionEnabled { + return + } + key := strings.TrimSpace(sessionKey) + if key == "" { + return + } + al.compactionMu.Lock() + if _, ok := al.compactionInflight[key]; ok { + al.compactionDirty[key] = struct{}{} + al.compactionMu.Unlock() + al.signalCompactionWorker() + return + } + if _, ok := al.compactionQueued[key]; ok { + al.compactionMu.Unlock() + al.signalCompactionWorker() + return + } + al.compactionQueue = append(al.compactionQueue, key) + al.compactionQueued[key] = struct{}{} + al.compactionMu.Unlock() + al.ensureCompactionRunnerStarted() + al.signalCompactionWorker() +} + +func (al *AgentLoop) dequeueSessionCompaction() (string, bool) { + if al == nil { + return "", false + } + al.compactionMu.Lock() + defer al.compactionMu.Unlock() + if len(al.compactionQueue) == 0 { + return "", false + } + key := al.compactionQueue[0] + al.compactionQueue = al.compactionQueue[1:] + delete(al.compactionQueued, key) + al.compactionInflight[key] = struct{}{} + return key, true +} + +func (al *AgentLoop) finishSessionCompaction(sessionKey string, retry bool) { + if al == nil { + return + } + key := strings.TrimSpace(sessionKey) + if key == "" { + return + } + shouldWake := false + al.compactionMu.Lock() + delete(al.compactionInflight, key) + if _, dirty := 
al.compactionDirty[key]; dirty { + delete(al.compactionDirty, key) + retry = true + } + if retry { + if _, queued := al.compactionQueued[key]; !queued { + al.compactionQueue = append(al.compactionQueue, key) + al.compactionQueued[key] = struct{}{} + shouldWake = true + } + } + al.compactionMu.Unlock() + if shouldWake { + al.signalCompactionWorker() + } +} + +func (al *AgentLoop) runCompactionLoop(stop <-chan struct{}) { + if al == nil { + return + } + for { + key, ok := al.dequeueSessionCompaction() + if !ok { + select { + case <-stop: + return + case <-al.compactionSignal: + continue + } + } + retry := al.runSessionCompactionJob(stop, key) + al.finishSessionCompaction(key, retry) + } +} + +func (al *AgentLoop) runSessionCompactionJob(stop <-chan struct{}, sessionKey string) bool { + if al == nil || strings.TrimSpace(sessionKey) == "" { + return false + } + jobCtx, cancel := al.newBackgroundCompactionContext(stop, sessionKey) + defer cancel() + applied, changed, needsRetry := al.compactSessionIfNeeded(jobCtx, sessionKey) + if changed && !applied { + return true + } + return needsRetry +} + +func (al *AgentLoop) newBackgroundCompactionContext(stop <-chan struct{}, sessionKey string) (context.Context, context.CancelFunc) { + base, baseCancel := context.WithCancel(context.Background()) + done := make(chan struct{}) + go func() { + select { + case <-stop: + baseCancel() + case <-done: + } + }() + timeout := al.compactionTimeoutForSession(sessionKey) + ctx, timeoutCancel := context.WithTimeout(base, timeout) + return ctx, func() { + close(done) + timeoutCancel() + baseCancel() + } +} + +func (al *AgentLoop) compactionTimeoutForSession(sessionKey string) time.Duration { + timeout := 30 * time.Second + if al != nil && al.cfg != nil { + name := strings.TrimSpace(al.getSessionProvider(sessionKey)) + if name == "" && len(al.providerNames) > 0 { + name = al.providerNames[0] + } + if pc, ok := config.ProviderConfigByName(al.cfg, name); ok && pc.TimeoutSec > 0 { + 
providerTimeout := time.Duration(pc.TimeoutSec) * time.Second + if providerTimeout < timeout { + timeout = providerTimeout + } + } + } + if timeout < 5*time.Second { + timeout = 5 * time.Second + } + return timeout +} + +func (al *AgentLoop) shouldCompactSnapshot(ctx context.Context, sessionKey string, snapshot session.SessionCompactionSnapshot) (bool, int, float64) { trigger := al.compactionTrigger if trigger <= 0 { trigger = 60 } - keepRecent := al.compactionKeepRecent - if keepRecent <= 0 || keepRecent >= trigger { - keepRecent = trigger / 2 - if keepRecent < 10 { - keepRecent = 10 - } + pressure := al.estimateCompactionPressure(ctx, sessionKey, snapshot.History, snapshot.Summary) + if len(snapshot.History) <= trigger && pressure < al.compactionPressureThreshold { + return false, trigger, pressure } - h := al.sessions.GetHistory(sessionKey) - if len(h) <= trigger { - return + return true, trigger, pressure +} + +// compactSessionIfNeeded runs a single compaction pass for the session; applied reports whether compaction was written, changed whether the snapshot moved underneath it, and needsRetry whether the caller should re-enqueue. 
+func (al *AgentLoop) compactSessionIfNeeded(ctx context.Context, sessionKey string) (applied bool, changed bool, needsRetry bool) { + if !al.compactionEnabled { + return false, false, false } - removed := len(h) - keepRecent - tpl := "[runtime-compaction] removed %d old messages, kept %d recent messages" - note := fmt.Sprintf(tpl, removed, keepRecent) - if al.sessions.CompactSession(sessionKey, keepRecent, note) { - _ = al.sessions.Save(al.sessions.GetOrCreate(sessionKey)) + if ctx == nil { + ctx = context.Background() } + snapshot := al.sessions.CompactionSnapshot(sessionKey) + if len(snapshot.History) == 0 { + return false, false, false + } + shouldCompact, trigger, _ := al.shouldCompactSnapshot(ctx, sessionKey, snapshot) + if !shouldCompact { + return false, false, false + } + keepRecent := al.compactionKeepCount(len(snapshot.History), trigger) + if keepRecent >= len(snapshot.History) { + return false, false, false + } + removed := len(snapshot.History) - keepRecent + older := append([]providers.Message(nil), snapshot.History[:removed]...) + recent := append([]providers.Message(nil), snapshot.History[removed:]...) + summary := al.buildCompactionSummary(ctx, sessionKey, older, snapshot.Summary) + if ctx.Err() != nil { + return false, false, true + } + if summary == "" { + summary = fmt.Sprintf("Key Facts\n- Runtime compaction removed %d older messages.\n\nDecisions\n\nOpen Items\n\nNext Steps", removed) + } + if al.sessions.ApplyCompactionIfUnchanged(sessionKey, snapshot.NextSeq, snapshot.Summary, recent, summary) { + return true, false, false + } + next := al.sessions.CompactionSnapshot(sessionKey) + shouldRetry, _, _ := al.shouldCompactSnapshot(context.Background(), sessionKey, next) + return false, true, shouldRetry } // RunStartupSelfCheckAllSessions runs startup compaction checks across loaded sessions. 
@@ -1975,35 +2324,16 @@ func (al *AgentLoop) RunStartupSelfCheckAllSessions(ctx context.Context) Startup return report } - trigger := al.compactionTrigger - if trigger <= 0 { - trigger = 60 - } - keepRecent := al.compactionKeepRecent - if keepRecent <= 0 || keepRecent >= trigger { - keepRecent = trigger / 2 - if keepRecent < 10 { - keepRecent = 10 - } - } - for _, key := range al.sessions.Keys() { select { case <-ctx.Done(): return report default: } - - history := al.sessions.GetHistory(key) - if len(history) <= trigger { - continue - } - - removed := len(history) - keepRecent - tpl := "[startup-compaction] removed %d old messages, kept %d recent messages" - note := fmt.Sprintf(tpl, removed, keepRecent) - if al.sessions.CompactSession(key, keepRecent, note) { - al.sessions.Save(al.sessions.GetOrCreate(key)) + jobCtx, cancel := context.WithTimeout(ctx, al.compactionTimeoutForSession(key)) + applied, _, _ := al.compactSessionIfNeeded(jobCtx, key) + cancel() + if applied { report.CompactedSessions++ } } @@ -2011,6 +2341,201 @@ func (al *AgentLoop) RunStartupSelfCheckAllSessions(ctx context.Context) Startup return report } +func (al *AgentLoop) buildCompactionSummary(ctx context.Context, sessionKey string, older []providers.Message, existingSummary string) string { + if al == nil || al.provider == nil || len(older) == 0 { + return strings.TrimSpace(existingSummary) + } + pruned := pruneCompactionMessages(older, al.compactionMaxTranscriptChars) + if compactor, ok := al.provider.(providers.ResponsesCompactor); ok && compactor.SupportsResponsesCompact() { + if summary, err := compactor.BuildSummaryViaResponsesCompact(ctx, al.model, existingSummary, pruned, al.compactionMaxSummaryChars); err == nil { + return strings.TrimSpace(summary) + } + } + + payload, err := json.Marshal(compactionTranscript(pruned)) + if err != nil { + return strings.TrimSpace(existingSummary) + } + resp, err := al.provider.Chat(ctx, []providers.Message{ + { + Role: "system", + Content: "You are 
compacting a previous conversation window for a future handoff. " + + "Return concise markdown with exactly these sections: Key Facts, Decisions, Open Items, Next Steps. " + + "Preserve concrete decisions, file paths, errors, and pending work. Do not address the user.", + }, + { + Role: "user", + Content: "Existing summary:\n" + strings.TrimSpace(existingSummary) + "\n\nEarlier conversation JSON:\n" + string(payload), + }, + }, nil, al.model, map[string]interface{}{"max_tokens": 900}) + if err != nil || resp == nil { + return strings.TrimSpace(existingSummary) + } + out := strings.TrimSpace(resp.Content) + if al.compactionMaxSummaryChars > 0 && len(out) > al.compactionMaxSummaryChars { + out = out[:al.compactionMaxSummaryChars] + } + return out +} + +func pruneCompactionMessages(messages []providers.Message, maxTranscriptChars int) []providers.Message { + out := make([]providers.Message, 0, len(messages)) + remaining := maxTranscriptChars + for _, msg := range messages { + cp := msg + if strings.EqualFold(strings.TrimSpace(cp.Role), "tool") && len(cp.Content) > 400 { + cp.Content = "[older tool output trimmed during context compaction]" + } else if len(cp.Content) > 2000 { + cp.Content = cp.Content[:2000] + "\n...[truncated for compaction]..." + } + if remaining > 0 && len(cp.Content) > remaining { + cp.Content = strings.TrimSpace(cp.Content[:remaining]) + "\n...[truncated for compaction budget]..." 
+ } + if remaining > 0 { + remaining -= len(cp.Content) + if remaining <= 0 { + out = append(out, cp) + break + } + } + out = append(out, cp) + } + return out +} + +func compactionTranscript(messages []providers.Message) []map[string]interface{} { + out := make([]map[string]interface{}, 0, len(messages)) + for _, msg := range messages { + entry := map[string]interface{}{ + "role": msg.Role, + "content": strings.TrimSpace(msg.Content), + } + if strings.TrimSpace(msg.ToolCallID) != "" { + entry["tool_call_id"] = msg.ToolCallID + } + if len(msg.ToolCalls) > 0 { + entry["tool_calls"] = msg.ToolCalls + } + out = append(out, entry) + } + return out +} + +func (al *AgentLoop) compactionKeepCount(historyLen, trigger int) int { + if historyLen <= 0 { + return 0 + } + keep := al.compactionProtectLastN + if keep <= 0 { + keep = al.compactionKeepRecent + } + if keep <= 0 { + keep = 12 + } + if ratio := al.compactionTargetRatio; ratio > 0 { + ratioKeep := int(math.Ceil(float64(historyLen) * ratio)) + if ratioKeep > keep { + keep = ratioKeep + } + } + if trigger > 0 && keep >= trigger { + keep = maxLoopInt(trigger-1, al.compactionProtectLastN, 1) + } + if keep >= historyLen { + keep = historyLen - 1 + } + if keep < 1 { + keep = 1 + } + return keep +} + +func (al *AgentLoop) estimateCompactionPressure(ctx context.Context, sessionKey string, history []providers.Message, summary string) float64 { + maxTokens := al.maxTokensForSession(sessionKey) + if maxTokens <= 0 { + return 0 + } + estimateMessages := make([]providers.Message, 0, len(history)+1) + estimateMessages = append(estimateMessages, history...) + if strings.TrimSpace(summary) != "" { + estimateMessages = append([]providers.Message{{ + Role: "system", + Content: "Handoff summary:\n" + strings.TrimSpace(summary), + }}, estimateMessages...) 
+ } + if activeProvider, activeModel, _, err := al.activeProviderForSession(sessionKey); err == nil { + if counter, ok := activeProvider.(providers.TokenCounter); ok { + if usage, err := counter.CountTokens(ctx, estimateMessages, nil, activeModel, nil); err == nil && usage != nil { + total := usage.TotalTokens + if total <= 0 { + total = usage.PromptTokens + } + if total > 0 { + return float64(total) / float64(maxTokens) + } + } + } + } + charCount := 0 + for _, msg := range estimateMessages { + charCount += len(msg.Content) + } + approxTokens := charCount / 4 + if approxTokens <= 0 && charCount > 0 { + approxTokens = 1 + } + return float64(approxTokens) / float64(maxTokens) +} + +func normalizePositiveInt(value, fallback int) int { + if value > 0 { + return value + } + return fallback +} + +func normalizeCompactionRatio(value, fallback float64) float64 { + if value > 0 && value <= 1 { + return value + } + return fallback +} + +func maxLoopInt(values ...int) int { + best := 0 + for _, value := range values { + if value > best { + best = value + } + } + return best +} + +func classifyLLMFailureCode(err error) string { + if err == nil { + return "" + } + if errors.Is(err, context.DeadlineExceeded) { + return "timeout" + } + if code := providers.ExecutionErrorCode(err); strings.TrimSpace(code) != "" { + return strings.TrimSpace(code) + } + lower := strings.ToLower(strings.TrimSpace(err.Error())) + switch { + case strings.Contains(lower, "max tool iterations"): + return "retry_limit" + case strings.Contains(lower, "stream stale"): + return "stream_stale" + case strings.Contains(lower, "stream failed"): + return "stream_failed" + case strings.Contains(lower, "thinking budget exhausted"), strings.Contains(lower, "continuation exhausted"): + return "continuation_exhausted" + default: + return "" + } +} + func (al *AgentLoop) GetStartupInfo() map[string]interface{} { info := make(map[string]interface{}) diff --git a/pkg/agent/loop_codex_options_test.go 
b/pkg/agent/loop_codex_options_test.go index 517f302..348daa7 100644 --- a/pkg/agent/loop_codex_options_test.go +++ b/pkg/agent/loop_codex_options_test.go @@ -193,10 +193,13 @@ func TestTryFallbackProvidersUsesFallbackProviderOptionsAndPersistsSelection(t * }, } - resp, providerName, err := loop.tryFallbackProviders(context.Background(), bus.InboundMessage{SessionKey: "chat-1"}, nil, nil, errors.New("primary failed")) + resp, providerName, attempts, err := loop.tryFallbackProviders(context.Background(), bus.InboundMessage{SessionKey: "chat-1"}, nil, nil, errors.New("primary failed")) if err != nil { t.Fatalf("expected fallback success, got %v", err) } + if attempts != 1 { + t.Fatalf("expected one fallback attempt, got %d", attempts) + } if resp == nil || resp.Content != "fallback" { t.Fatalf("unexpected fallback response: %#v", resp) } diff --git a/pkg/agent/reliability_test.go b/pkg/agent/reliability_test.go new file mode 100644 index 0000000..615ffdc --- /dev/null +++ b/pkg/agent/reliability_test.go @@ -0,0 +1,348 @@ +package agent + +import ( + "context" + "strings" + "sync" + "testing" + "time" + + "github.com/YspCoder/clawgo/pkg/bus" + "github.com/YspCoder/clawgo/pkg/lifecycle" + "github.com/YspCoder/clawgo/pkg/providers" + "github.com/YspCoder/clawgo/pkg/session" + toolspkg "github.com/YspCoder/clawgo/pkg/tools" +) + +type pressureProvider struct { + tokens int +} + +func (p *pressureProvider) Chat(ctx context.Context, messages []providers.Message, tools []providers.ToolDefinition, model string, options map[string]interface{}) (*providers.LLMResponse, error) { + return &providers.LLMResponse{Content: "Key Facts\n- compacted", FinishReason: "stop"}, nil +} + +func (p *pressureProvider) GetDefaultModel() string { return "pressure-model" } + +func (p *pressureProvider) CountTokens(ctx context.Context, messages []providers.Message, tools []providers.ToolDefinition, model string, options map[string]interface{}) (*providers.UsageInfo, error) { + return 
&providers.UsageInfo{PromptTokens: p.tokens, TotalTokens: p.tokens}, nil +} + +type fallbackStreamingProvider struct { + stream func(ctx context.Context, onDelta func(string)) (*providers.LLMResponse, error) + chat func(ctx context.Context) (*providers.LLMResponse, error) +} + +func (p *fallbackStreamingProvider) Chat(ctx context.Context, messages []providers.Message, tools []providers.ToolDefinition, model string, options map[string]interface{}) (*providers.LLMResponse, error) { + return p.chat(ctx) +} + +func (p *fallbackStreamingProvider) ChatStream(ctx context.Context, messages []providers.Message, tools []providers.ToolDefinition, model string, options map[string]interface{}, onDelta func(string)) (*providers.LLMResponse, error) { + return p.stream(ctx, onDelta) +} + +func (p *fallbackStreamingProvider) GetDefaultModel() string { return "stream-model" } + +type asyncCompactionProvider struct { + mu sync.Mutex + started chan int + release chan struct{} + finished chan int + calls int +} + +func (p *asyncCompactionProvider) Chat(ctx context.Context, messages []providers.Message, tools []providers.ToolDefinition, model string, options map[string]interface{}) (*providers.LLMResponse, error) { + p.mu.Lock() + p.calls++ + call := p.calls + p.mu.Unlock() + if p.started != nil { + p.started <- call + } + if p.release != nil { + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-p.release: + } + } + if p.finished != nil { + p.finished <- call + } + return &providers.LLMResponse{Content: "Key Facts\n- compacted", FinishReason: "stop"}, nil +} + +func (p *asyncCompactionProvider) GetDefaultModel() string { return "async-model" } + +func TestCompactSessionTriggeredByTokenPressure(t *testing.T) { + t.Parallel() + + sm := session.NewSessionManager(t.TempDir()) + key := "cli:pressure" + for _, content := range []string{"one", "two", "three", "four", "five", "six"} { + sm.AddMessage(key, "user", content) + } + provider := &pressureProvider{tokens: 900} + loop := 
&AgentLoop{ + provider: provider, + model: provider.GetDefaultModel(), + maxTokens: 1000, + providerNames: []string{"pressure"}, + sessions: sm, + compactionEnabled: true, + compactionTrigger: 100, + compactionProtectLastN: 2, + compactionKeepRecent: 2, + compactionTargetRatio: 0.35, + compactionPressureThreshold: 0.8, + compactionMaxSummaryChars: 6000, + compactionMaxTranscriptChars: 20000, + } + + applied, _, _ := loop.compactSessionIfNeeded(context.Background(), key) + if !applied { + t.Fatal("expected compaction to apply") + } + + history := sm.GetPromptHistory(key) + if len(history) != 3 { + t.Fatalf("expected ratio-based keep count 3, got %d", len(history)) + } + if history[0].Content != "four" || history[2].Content != "six" { + t.Fatalf("expected tail messages preserved, got %#v", history) + } + if summary := sm.GetSummary(key); summary == "" { + t.Fatal("expected compaction summary to be written") + } +} + +func TestFinalizeUserMessageDoesNotWaitForCompaction(t *testing.T) { + t.Parallel() + + sm := session.NewSessionManager(t.TempDir()) + key := "cli:async" + for _, content := range []string{"one", "two", "three", "four", "five", "six"} { + sm.AddMessage(key, "user", content) + } + provider := &asyncCompactionProvider{ + started: make(chan int, 2), + release: make(chan struct{}), + } + loop := &AgentLoop{ + provider: provider, + model: provider.GetDefaultModel(), + sessions: sm, + compactionEnabled: true, + compactionTrigger: 4, + compactionProtectLastN: 2, + compactionKeepRecent: 2, + compactionTargetRatio: 0.35, + compactionPressureThreshold: 0.1, + compactionMaxSummaryChars: 6000, + compactionMaxTranscriptChars: 20000, + compactionRunner: lifecycle.NewLoopRunner(), + compactionSignal: make(chan struct{}, 1), + compactionQueued: map[string]struct{}{}, + compactionInflight: map[string]struct{}{}, + compactionDirty: map[string]struct{}{}, + } + t.Cleanup(loop.Stop) + + start := time.Now() + loop.finalizeUserMessage(key, "en", nil, "final") + if elapsed := 
time.Since(start); elapsed > 150*time.Millisecond { + t.Fatalf("expected finalizeUserMessage to return quickly, took %s", elapsed) + } + select { + case <-provider.started: + case <-time.After(500 * time.Millisecond): + t.Fatal("expected async compaction to start in background") + } + close(provider.release) + deadline := time.Now().Add(2 * time.Second) + for time.Now().Before(deadline) { + if summary := sm.GetSummary(key); summary != "" { + return + } + time.Sleep(20 * time.Millisecond) + } + t.Fatal("expected async compaction summary to be written") +} + +func TestCompactionWorkerRetriesDirtySessionWithoutLosingNewMessages(t *testing.T) { + t.Parallel() + + sm := session.NewSessionManager(t.TempDir()) + key := "cli:dirty" + for _, content := range []string{"one", "two", "three", "four", "five", "six"} { + sm.AddMessage(key, "user", content) + } + provider := &asyncCompactionProvider{ + started: make(chan int, 4), + release: make(chan struct{}, 4), + finished: make(chan int, 4), + } + loop := &AgentLoop{ + provider: provider, + model: provider.GetDefaultModel(), + sessions: sm, + compactionEnabled: true, + compactionTrigger: 4, + compactionProtectLastN: 2, + compactionKeepRecent: 2, + compactionTargetRatio: 0.35, + compactionPressureThreshold: 0.1, + compactionMaxSummaryChars: 6000, + compactionMaxTranscriptChars: 20000, + compactionRunner: lifecycle.NewLoopRunner(), + compactionSignal: make(chan struct{}, 1), + compactionQueued: map[string]struct{}{}, + compactionInflight: map[string]struct{}{}, + compactionDirty: map[string]struct{}{}, + } + t.Cleanup(loop.Stop) + + loop.enqueueSessionCompaction(key) + select { + case <-provider.started: + case <-time.After(time.Second): + t.Fatal("expected first compaction run to start") + } + + sm.AddMessage(key, "assistant", "seven") + loop.enqueueSessionCompaction(key) + provider.release <- struct{}{} + + select { + case <-provider.finished: + case <-time.After(time.Second): + t.Fatal("expected first compaction run to 
finish") + } + select { + case call := <-provider.started: + if call != 2 { + t.Fatalf("expected second compaction attempt after dirty retry, got call %d", call) + } + case <-time.After(time.Second): + t.Fatal("expected dirty session to trigger a second compaction run") + } + provider.release <- struct{}{} + select { + case <-provider.finished: + case <-time.After(time.Second): + t.Fatal("expected second compaction run to finish") + } + + deadline := time.Now().Add(2 * time.Second) + for time.Now().Before(deadline) { + history := sm.GetPromptHistory(key) + if len(history) > 0 && history[len(history)-1].Content == "seven" && sm.GetSummary(key) != "" { + return + } + time.Sleep(20 * time.Millisecond) + } + t.Fatal("expected retried compaction to preserve new message and summary") +} + +func TestRequestStreamingLLMResponseFallsBackBeforeFirstDelta(t *testing.T) { + t.Parallel() + + provider := &fallbackStreamingProvider{ + stream: func(ctx context.Context, onDelta func(string)) (*providers.LLMResponse, error) { + <-ctx.Done() + return nil, providers.NewProviderExecutionError("stream_stale", "stream stale", "stream", true, "test") + }, + chat: func(ctx context.Context) (*providers.LLMResponse, error) { + return &providers.LLMResponse{Content: "fallback", FinishReason: "stop"}, nil + }, + } + loop := &AgentLoop{ + bus: bus.NewMessageBus(), + sessionStreamed: map[string]bool{}, + } + ctx, cancel := context.WithTimeout(context.Background(), 400*time.Millisecond) + defer cancel() + + resp, attempts, err := loop.requestStreamingLLMResponse(llmTurnLoopConfig{ + ctx: ctx, + sessionKey: "cli:test", + toolChannel: "telegram", + toolChatID: "chat", + enableStreaming: true, + }, provider, provider, provider.GetDefaultModel(), []providers.Message{{Role: "user", Content: "hello"}}, nil, nil) + if err != nil { + t.Fatalf("expected fallback success, got %v", err) + } + if attempts != 2 { + t.Fatalf("expected streaming + fallback attempts, got %d", attempts) + } + if resp == nil || 
resp.Content != "fallback" { + t.Fatalf("unexpected fallback response: %#v", resp) + } +} + +func TestRequestStreamingLLMResponseDoesNotFallbackAfterDelta(t *testing.T) { + t.Parallel() + + provider := &fallbackStreamingProvider{ + stream: func(ctx context.Context, onDelta func(string)) (*providers.LLMResponse, error) { + onDelta("partial") + return nil, providers.NewProviderExecutionError("stream_failed", "stream failed", "stream", true, "test") + }, + chat: func(ctx context.Context) (*providers.LLMResponse, error) { + return &providers.LLMResponse{Content: "fallback", FinishReason: "stop"}, nil + }, + } + loop := &AgentLoop{ + bus: bus.NewMessageBus(), + sessionStreamed: map[string]bool{}, + } + + resp, attempts, err := loop.requestStreamingLLMResponse(llmTurnLoopConfig{ + ctx: context.Background(), + sessionKey: "cli:test", + toolChannel: "telegram", + toolChatID: "chat", + enableStreaming: true, + }, provider, provider, provider.GetDefaultModel(), []providers.Message{{Role: "user", Content: "hello"}}, nil, nil) + if err == nil { + t.Fatal("expected stream failure without fallback") + } + if attempts != 1 { + t.Fatalf("expected single streaming attempt, got %d", attempts) + } + if resp != nil { + t.Fatalf("expected nil response on post-delta stream failure, got %#v", resp) + } +} + +func TestRunLLMTurnLoopReturnsRetryLimitError(t *testing.T) { + t.Parallel() + + provider := &sequenceProvider{ + responses: []*providers.LLMResponse{{ + Content: "", + ToolCalls: []providers.ToolCall{ + {ID: "tool-1", Name: "system_info", Arguments: map[string]interface{}{}}, + }, + FinishReason: "tool_calls", + }}, + } + loop := &AgentLoop{ + provider: provider, + model: provider.GetDefaultModel(), + maxIterations: 1, + tools: toolspkg.NewToolRegistry(), + providerNames: []string{"sequence"}, + sessionProvider: map[string]string{}, + } + loop.tools.Register(toolspkg.NewSystemInfoTool()) + _, err := loop.runLLMTurnLoop(llmTurnLoopConfig{ + ctx: context.Background(), + sessionKey: 
"cli:test", + messages: []providers.Message{{Role: "user", Content: "hello"}}, + }) + if err == nil || !strings.Contains(err.Error(), "max tool iterations exceeded") { + t.Fatalf("expected retry limit error, got %v", err) + } +} diff --git a/pkg/api/server.go b/pkg/api/server.go index b17f9a9..6d51269 100644 --- a/pkg/api/server.go +++ b/pkg/api/server.go @@ -47,7 +47,8 @@ type Server struct { workspacePath string logFilePath string onChat func(ctx context.Context, sessionKey, content string) (string, error) - onChatHistory func(sessionKey string) []map[string]interface{} + onChatHistory func(query ChatHistoryQuery) []map[string]interface{} + onSessionSearch func(query SessionSearchQuery) []map[string]interface{} onConfigAfter func(forceRuntimeReload bool) error onCron func(action string, args map[string]interface{}) (interface{}, error) onToolsCatalog func() interface{} @@ -70,6 +71,22 @@ type channelDraftStore struct { weixinRuntime *channels.WeixinChannel } +type ChatHistoryQuery struct { + Session string + Around int + Before int + After int + Limit int +} + +type SessionSearchQuery struct { + Query string + Limit int + Kinds []string + ExcludeCurrent bool + Session string +} + func NewServer(host string, port int, token string) *Server { addr := strings.TrimSpace(host) if addr == "" { @@ -94,9 +111,12 @@ func (s *Server) SetToken(token string) { s.token = strings.TrimSpace(tok func (s *Server) SetChatHandler(fn func(ctx context.Context, sessionKey, content string) (string, error)) { s.onChat = fn } -func (s *Server) SetChatHistoryHandler(fn func(sessionKey string) []map[string]interface{}) { +func (s *Server) SetChatHistoryHandler(fn func(query ChatHistoryQuery) []map[string]interface{}) { s.onChatHistory = fn } +func (s *Server) SetSessionSearchHandler(fn func(query SessionSearchQuery) []map[string]interface{}) { + s.onSessionSearch = fn +} func (s *Server) SetMessageBus(mb *bus.MessageBus) { s.messageBus = mb } func (s *Server) SetConfigAfterHook(fn 
func(forceRuntimeReload bool) error) { s.onConfigAfter = fn } func (s *Server) SetCronHandler(fn func(action string, args map[string]interface{}) (interface{}, error)) { @@ -611,6 +631,7 @@ func (s *Server) Start(ctx context.Context) error { mux.HandleFunc("/api/cron", s.handleWebUICron) mux.HandleFunc("/api/skills", s.handleWebUISkills) mux.HandleFunc("/api/sessions", s.handleWebUISessions) + mux.HandleFunc("/api/sessions/search", s.handleWebUISessionSearch) mux.HandleFunc("/api/memory", s.handleWebUIMemory) mux.HandleFunc("/api/workspace_file", s.handleWebUIWorkspaceFile) mux.HandleFunc("/api/workspace_docs", s.handleWebUIWorkspaceDocs) @@ -1516,7 +1537,14 @@ func (s *Server) handleWebUIChatHistory(w http.ResponseWriter, r *http.Request) writeJSON(w, map[string]interface{}{"ok": true, "session": session, "messages": []interface{}{}}) return } - writeJSON(w, map[string]interface{}{"ok": true, "session": session, "messages": s.onChatHistory(session)}) + query := ChatHistoryQuery{ + Session: session, + Around: queryBoundedPositiveInt(r, "around", 0, 1_000_000), + Before: queryBoundedPositiveInt(r, "before", 0, 1_000_000), + After: queryBoundedPositiveInt(r, "after", 0, 1_000_000), + Limit: queryBoundedPositiveInt(r, "limit", 200, 2000), + } + writeJSON(w, map[string]interface{}{"ok": true, "session": session, "messages": s.onChatHistory(query)}) } func (s *Server) handleWebUIChatLive(w http.ResponseWriter, r *http.Request) { @@ -3349,10 +3377,25 @@ func (s *Server) handleWebUISessions(w http.ResponseWriter, r *http.Request) { continue } name := e.Name() - if !strings.HasSuffix(name, ".jsonl") || strings.Contains(name, ".deleted.") { + key := "" + switch { + case strings.HasSuffix(name, ".meta.json"): + key = strings.TrimSuffix(name, ".meta.json") + case strings.HasSuffix(name, ".active.jsonl"): + key = strings.TrimSuffix(name, ".active.jsonl") + case strings.HasSuffix(name, ".jsonl") && !strings.Contains(name, ".deleted."): + key = strings.TrimSuffix(name, ".jsonl") 
+ if strings.HasSuffix(key, ".active") { + key = strings.TrimSuffix(key, ".active") + } + if idx := strings.LastIndex(key, "."); idx > 0 { + if seqPart := key[idx+1:]; len(seqPart) == 4 && regexp.MustCompile(`^\d{4}$`).MatchString(seqPart) { + key = key[:idx] + } + } + default: continue } - key := strings.TrimSuffix(name, ".jsonl") if strings.TrimSpace(key) == "" { continue } @@ -3376,6 +3419,65 @@ func (s *Server) handleWebUISessions(w http.ResponseWriter, r *http.Request) { writeJSON(w, map[string]interface{}{"ok": true, "sessions": out}) } +func (s *Server) handleWebUISessionSearch(w http.ResponseWriter, r *http.Request) { + if !s.checkAuth(r) { + http.Error(w, "unauthorized", http.StatusUnauthorized) + return + } + if r.Method != http.MethodGet { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + if s.onSessionSearch == nil { + writeJSON(w, map[string]interface{}{"ok": true, "results": []interface{}{}}) + return + } + queryText := strings.TrimSpace(r.URL.Query().Get("query")) + if queryText == "" { + writeJSON(w, map[string]interface{}{"ok": true, "results": []interface{}{}}) + return + } + kinds := splitCSVQueryParam(r.URL.Query()["kinds"]) + if len(kinds) == 0 { + kinds = splitCSVQueryParam([]string{r.URL.Query().Get("kind")}) + } + excludeCurrent := false + if raw := strings.TrimSpace(r.URL.Query().Get("exclude_current")); raw != "" { + excludeCurrent = raw == "1" || strings.EqualFold(raw, "true") || strings.EqualFold(raw, "yes") + } + query := SessionSearchQuery{ + Query: queryText, + Limit: queryBoundedPositiveInt(r, "limit", 5, 100), + Kinds: kinds, + ExcludeCurrent: excludeCurrent, + Session: strings.TrimSpace(r.URL.Query().Get("session")), + } + writeJSON(w, map[string]interface{}{ + "ok": true, + "query": query.Query, + "results": s.onSessionSearch(query), + }) +} + +func splitCSVQueryParam(values []string) []string { + out := make([]string, 0, len(values)) + seen := map[string]struct{}{} + for _, value := range values { 
+ for _, item := range strings.Split(value, ",") { + item = strings.TrimSpace(item) + if item == "" { + continue + } + if _, ok := seen[item]; ok { + continue + } + seen[item] = struct{}{} + out = append(out, item) + } + } + return out +} + func isUserFacingSessionKey(key string) bool { k := strings.ToLower(strings.TrimSpace(key)) if k == "" { diff --git a/pkg/api/server_test.go b/pkg/api/server_test.go index 3ac1f9e..4b4f6b2 100644 --- a/pkg/api/server_test.go +++ b/pkg/api/server_test.go @@ -408,6 +408,73 @@ func TestHandleWebUISessionsHidesInternalSessionsByDefault(t *testing.T) { } } +func TestHandleWebUIChatHistorySupportsWindowQuery(t *testing.T) { + t.Parallel() + + srv := NewServer("127.0.0.1", 0, "") + var got ChatHistoryQuery + srv.SetChatHistoryHandler(func(query ChatHistoryQuery) []map[string]interface{} { + got = query + return []map[string]interface{}{{"role": "assistant", "content": "ok"}} + }) + + req := httptest.NewRequest(http.MethodGet, "/api/chat/history?session=alpha&after=2&limit=3", nil) + rec := httptest.NewRecorder() + srv.handleWebUIChatHistory(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("expected 200, got %d: %s", rec.Code, rec.Body.String()) + } + if got.Session != "alpha" || got.After != 2 || got.Limit != 3 { + t.Fatalf("unexpected query: %+v", got) + } +} + +func TestHandleWebUISessionSearchReturnsResults(t *testing.T) { + t.Parallel() + + srv := NewServer("127.0.0.1", 0, "") + var got SessionSearchQuery + srv.SetSessionSearchHandler(func(query SessionSearchQuery) []map[string]interface{} { + got = query + return []map[string]interface{}{ + { + "key": "main", + "kind": "main", + "updated_at": int64(123), + "summary": "deploy notes", + "score": 2, + "snippets": []map[string]interface{}{{"seq": 3, "content": "deploy timeout"}}, + }, + } + }) + + req := httptest.NewRequest(http.MethodGet, "/api/sessions/search?query=deploy&kinds=main,cron&exclude_current=1&session=current&limit=7", nil) + rec := httptest.NewRecorder() + 
srv.handleWebUISessionSearch(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("expected 200, got %d: %s", rec.Code, rec.Body.String()) + } + if got.Query != "deploy" || got.Session != "current" || !got.ExcludeCurrent || got.Limit != 7 { + t.Fatalf("unexpected search query: %+v", got) + } + if len(got.Kinds) != 2 || got.Kinds[0] != "main" || got.Kinds[1] != "cron" { + t.Fatalf("unexpected kinds: %+v", got.Kinds) + } + + var payload struct { + OK bool `json:"ok"` + Results []map[string]interface{} `json:"results"` + } + if err := json.Unmarshal(rec.Body.Bytes(), &payload); err != nil { + t.Fatalf("decode payload: %v", err) + } + if !payload.OK || len(payload.Results) != 1 { + t.Fatalf("unexpected payload: %+v", payload) + } +} + func TestSaveProviderConfigForcesRuntimeReload(t *testing.T) { t.Parallel() diff --git a/pkg/config/config.go b/pkg/config/config.go index cd729af..61e26f4 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -106,15 +106,16 @@ type SubagentToolsConfig struct { } type SubagentRuntimeConfig struct { - Provider string `json:"provider,omitempty"` - Model string `json:"model,omitempty"` - Temperature float64 `json:"temperature,omitempty"` - TimeoutSec int `json:"timeout_sec,omitempty"` - MaxRetries int `json:"max_retries,omitempty"` - RetryBackoffMs int `json:"retry_backoff_ms,omitempty"` - MaxTaskChars int `json:"max_task_chars,omitempty"` - MaxResultChars int `json:"max_result_chars,omitempty"` - MaxParallelRuns int `json:"max_parallel_runs,omitempty"` + Provider string `json:"provider,omitempty"` + Model string `json:"model,omitempty"` + Temperature float64 `json:"temperature,omitempty"` + TimeoutSec int `json:"timeout_sec,omitempty"` + MaxRetries int `json:"max_retries,omitempty"` + RetryBackoffMs int `json:"retry_backoff_ms,omitempty"` + MaxTaskChars int `json:"max_task_chars,omitempty"` + MaxResultChars int `json:"max_result_chars,omitempty"` + MaxParallelRuns int `json:"max_parallel_runs,omitempty"` + MaxToolIterations 
int `json:"max_tool_iterations,omitempty"` } type AgentDefaults struct { @@ -158,12 +159,15 @@ type SystemSummaryPolicyConfig struct { } type ContextCompactionConfig struct { - Enabled bool `json:"enabled" env:"CLAWGO_AGENTS_DEFAULTS_CONTEXT_COMPACTION_ENABLED"` - Mode string `json:"mode" env:"CLAWGO_AGENTS_DEFAULTS_CONTEXT_COMPACTION_MODE"` - TriggerMessages int `json:"trigger_messages" env:"CLAWGO_AGENTS_DEFAULTS_CONTEXT_COMPACTION_TRIGGER_MESSAGES"` - KeepRecentMessages int `json:"keep_recent_messages" env:"CLAWGO_AGENTS_DEFAULTS_CONTEXT_COMPACTION_KEEP_RECENT_MESSAGES"` - MaxSummaryChars int `json:"max_summary_chars" env:"CLAWGO_AGENTS_DEFAULTS_CONTEXT_COMPACTION_MAX_SUMMARY_CHARS"` - MaxTranscriptChars int `json:"max_transcript_chars" env:"CLAWGO_AGENTS_DEFAULTS_CONTEXT_COMPACTION_MAX_TRANSCRIPT_CHARS"` + Enabled bool `json:"enabled" env:"CLAWGO_AGENTS_DEFAULTS_CONTEXT_COMPACTION_ENABLED"` + Mode string `json:"mode" env:"CLAWGO_AGENTS_DEFAULTS_CONTEXT_COMPACTION_MODE"` + TriggerMessages int `json:"trigger_messages" env:"CLAWGO_AGENTS_DEFAULTS_CONTEXT_COMPACTION_TRIGGER_MESSAGES"` + KeepRecentMessages int `json:"keep_recent_messages" env:"CLAWGO_AGENTS_DEFAULTS_CONTEXT_COMPACTION_KEEP_RECENT_MESSAGES"` + TargetRatio float64 `json:"target_ratio" env:"CLAWGO_AGENTS_DEFAULTS_CONTEXT_COMPACTION_TARGET_RATIO"` + ProtectLastN int `json:"protect_last_n" env:"CLAWGO_AGENTS_DEFAULTS_CONTEXT_COMPACTION_PROTECT_LAST_N"` + PressureThreshold float64 `json:"pressure_threshold" env:"CLAWGO_AGENTS_DEFAULTS_CONTEXT_COMPACTION_PRESSURE_THRESHOLD"` + MaxSummaryChars int `json:"max_summary_chars" env:"CLAWGO_AGENTS_DEFAULTS_CONTEXT_COMPACTION_MAX_SUMMARY_CHARS"` + MaxTranscriptChars int `json:"max_transcript_chars" env:"CLAWGO_AGENTS_DEFAULTS_CONTEXT_COMPACTION_MAX_TRANSCRIPT_CHARS"` } type ChannelsConfig struct { @@ -402,6 +406,9 @@ func DefaultConfig() *Config { Mode: "summary", TriggerMessages: 60, KeepRecentMessages: 20, + TargetRatio: 0.35, + ProtectLastN: 12, + 
PressureThreshold: 0.8, MaxSummaryChars: 6000, MaxTranscriptChars: 20000, }, diff --git a/pkg/config/validate.go b/pkg/config/validate.go index 1963e8d..ac2886f 100644 --- a/pkg/config/validate.go +++ b/pkg/config/validate.go @@ -76,6 +76,15 @@ func Validate(cfg *Config) []error { if cc.KeepRecentMessages <= 0 { errs = append(errs, fmt.Errorf("agents.defaults.context_compaction.keep_recent_messages must be > 0 when enabled=true")) } + if cc.TargetRatio <= 0 || cc.TargetRatio >= 1 { + errs = append(errs, fmt.Errorf("agents.defaults.context_compaction.target_ratio must be > 0 and < 1 when enabled=true")) + } + if cc.ProtectLastN <= 0 { + errs = append(errs, fmt.Errorf("agents.defaults.context_compaction.protect_last_n must be > 0 when enabled=true")) + } + if cc.PressureThreshold <= 0 || cc.PressureThreshold > 1 { + errs = append(errs, fmt.Errorf("agents.defaults.context_compaction.pressure_threshold must be > 0 and <= 1 when enabled=true")) + } if cc.TriggerMessages > 0 && cc.KeepRecentMessages >= cc.TriggerMessages { errs = append(errs, fmt.Errorf("agents.defaults.context_compaction.keep_recent_messages must be < trigger_messages")) } @@ -395,6 +404,9 @@ func validateSubagents(cfg *Config) []error { if raw.Runtime.MaxParallelRuns < 0 { errs = append(errs, fmt.Errorf("agents.subagents.%s.runtime.max_parallel_runs must be >= 0", id)) } + if raw.Runtime.MaxToolIterations < 0 { + errs = append(errs, fmt.Errorf("agents.subagents.%s.runtime.max_tool_iterations must be >= 0", id)) + } if raw.Tools.MaxParallelCalls < 0 { errs = append(errs, fmt.Errorf("agents.subagents.%s.tools.max_parallel_calls must be >= 0", id)) } diff --git a/pkg/jsonlog/store.go b/pkg/jsonlog/store.go new file mode 100644 index 0000000..7ac92ca --- /dev/null +++ b/pkg/jsonlog/store.go @@ -0,0 +1,96 @@ +package jsonlog + +import ( + "bufio" + "encoding/json" + "fmt" + "os" + "path/filepath" +) + +// AppendLine appends a single JSON-encoded line and returns the resulting file size. 
+func AppendLine(path string, value interface{}) (int64, error) { + data, err := json.Marshal(value) + if err != nil { + return 0, err + } + return AppendRawLine(path, data) +} + +// AppendRawLine appends a raw JSON line and returns the resulting file size. +func AppendRawLine(path string, line []byte) (int64, error) { + if err := os.MkdirAll(filepath.Dir(path), 0755); err != nil { + return 0, err + } + f, err := os.OpenFile(path, os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0644) + if err != nil { + return 0, err + } + defer f.Close() + if _, err := f.Write(append(line, '\n')); err != nil { + return 0, err + } + st, err := f.Stat() + if err != nil { + return 0, err + } + return st.Size(), nil +} + +// Scan walks each non-empty JSONL line in order. +func Scan(path string, fn func(line []byte) error) error { + f, err := os.Open(path) + if err != nil { + return err + } + defer f.Close() + + scanner := bufio.NewScanner(f) + scanner.Buffer(make([]byte, 0, 64*1024), 8*1024*1024) + for scanner.Scan() { + line := append([]byte(nil), scanner.Bytes()...) + if len(line) == 0 { + continue + } + if err := fn(line); err != nil { + return err + } + } + if err := scanner.Err(); err != nil { + return fmt.Errorf("scan %s: %w", path, err) + } + return nil +} + +// FileSize returns 0 when the file does not exist. +func FileSize(path string) (int64, error) { + st, err := os.Stat(path) + if err != nil { + if os.IsNotExist(err) { + return 0, nil + } + return 0, err + } + return st.Size(), nil +} + +// ReadJSON reads a JSON sidecar file into dst. +func ReadJSON(path string, dst interface{}) error { + data, err := os.ReadFile(path) + if err != nil { + return err + } + return json.Unmarshal(data, dst) +} + +// WriteJSON writes a JSON sidecar file atomically enough for local runtime use. 
+func WriteJSON(path string, value interface{}) error { + if err := os.MkdirAll(filepath.Dir(path), 0755); err != nil { + return err + } + data, err := json.MarshalIndent(value, "", " ") + if err != nil { + return err + } + return os.WriteFile(path, data, 0644) +} diff --git a/pkg/providers/execution.go b/pkg/providers/execution.go index 92ae768..b290c07 100644 --- a/pkg/providers/execution.go +++ b/pkg/providers/execution.go @@ -4,8 +4,10 @@ import ( "bytes" "context" "encoding/json" + "errors" "fmt" "net/http" + "time" ) func newProviderExecutionError(code, message, stage string, retryable bool, source string) *ProviderExecutionError { @@ -18,6 +20,18 @@ func newProviderExecutionError(code, message, stage string, retryable bool, sour } } +func NewProviderExecutionError(code, message, stage string, retryable bool, source string) *ProviderExecutionError { + return newProviderExecutionError(code, message, stage, retryable, source) +} + +func ExecutionErrorCode(err error) string { + var execErr *ProviderExecutionError + if errors.As(err, &execErr) && execErr != nil { + return execErr.Code + } + return "" +} + func (p *HTTPProvider) executeJSONAttempts(ctx context.Context, endpoint string, payload interface{}, mutate func(*http.Request, authAttempt), classify func(int, []byte) (oauthFailureReason, bool)) (ProviderExecutionResult, error) { jsonData, err := json.Marshal(payload) if err != nil { @@ -42,12 +56,13 @@ func (p *HTTPProvider) executeJSONAttempts(ctx context.Context, endpoint string, } body, status, ctype, err := p.doJSONAttempt(req, attempt) if err != nil { + execErr := newProviderExecutionError("request_failed", err.Error(), "request", false, p.providerName) return ProviderExecutionResult{ StatusCode: status, ContentType: ctype, AttemptKind: attempt.kind, - Error: newProviderExecutionError("request_failed", err.Error(), "request", false, p.providerName), - }, err + Error: execErr, + }, execErr } reason, retry := classify(status, body) last = 
ProviderExecutionResult{ @@ -78,8 +93,10 @@ func (p *HTTPProvider) executeStreamAttempts(ctx context.Context, endpoint strin } var last ProviderExecutionResult for _, attempt := range attempts { - req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(jsonData)) + attemptCtx, cancel := context.WithCancel(ctx) + req, err := http.NewRequestWithContext(attemptCtx, http.MethodPost, endpoint, bytes.NewReader(jsonData)) if err != nil { + cancel() return ProviderExecutionResult{Error: newProviderExecutionError("request_build_failed", err.Error(), "request", false, p.providerName)}, fmt.Errorf("failed to create request: %w", err) } req.Header.Set("Content-Type", "application/json") @@ -89,14 +106,21 @@ func (p *HTTPProvider) executeStreamAttempts(ctx context.Context, endpoint strin if mutate != nil { mutate(req, attempt) } - body, status, ctype, quotaHit, err := p.doStreamAttempt(req, attempt, onEvent) + streamOptions := streamAttemptTimeouts(ctx) + body, status, ctype, quotaHit, err := p.doStreamAttempt(req, attempt, onEvent, streamOptions, cancel) + cancel() if err != nil { + code := "stream_failed" + if streamOptions.staleTriggered { + code = "stream_stale" + } + execErr := newProviderExecutionError(code, err.Error(), "request", true, p.providerName) return ProviderExecutionResult{ StatusCode: status, ContentType: ctype, AttemptKind: attempt.kind, - Error: newProviderExecutionError("stream_failed", err.Error(), "request", false, p.providerName), - }, err + Error: execErr, + }, execErr } reason, _ := classifyOAuthFailure(status, body) last = ProviderExecutionResult{ @@ -115,3 +139,47 @@ func (p *HTTPProvider) executeStreamAttempts(ctx context.Context, endpoint strin } return last, nil } + +type streamAttemptOptions struct { + firstDeltaTimeout time.Duration + idleDeltaTimeout time.Duration + staleTriggered bool +} + +func streamAttemptTimeouts(ctx context.Context) *streamAttemptOptions { + opts := &streamAttemptOptions{ + 
firstDeltaTimeout: 20 * time.Second, + idleDeltaTimeout: 45 * time.Second, + } + if deadline, ok := ctx.Deadline(); ok { + remaining := time.Until(deadline) + if remaining > 0 { + first := remaining / 4 + idle := remaining / 3 + if remaining < 5*time.Second { + first = maxDuration(100*time.Millisecond, first) + } else if first < 5*time.Second { + first = 5 * time.Second + } + if remaining < 15*time.Second { + idle = maxDuration(250*time.Millisecond, idle) + } else if idle < 15*time.Second { + idle = 15 * time.Second + } + if first < opts.firstDeltaTimeout { + opts.firstDeltaTimeout = first + } + if idle < opts.idleDeltaTimeout { + opts.idleDeltaTimeout = idle + } + } + } + return opts +} + +func maxDuration(a, b time.Duration) time.Duration { + if a > b { + return a + } + return b +} diff --git a/pkg/providers/http_provider.go b/pkg/providers/http_provider.go index 7da55ff..94c2d92 100644 --- a/pkg/providers/http_provider.go +++ b/pkg/providers/http_provider.go @@ -983,7 +983,7 @@ func (p *HTTPProvider) doJSONAttempt(req *http.Request, attempt authAttempt) ([] return body, resp.StatusCode, strings.TrimSpace(resp.Header.Get("Content-Type")), nil } -func (p *HTTPProvider) doStreamAttempt(req *http.Request, attempt authAttempt, onEvent func(string)) ([]byte, int, string, bool, error) { +func (p *HTTPProvider) doStreamAttempt(req *http.Request, attempt authAttempt, onEvent func(string), options *streamAttemptOptions, cancel context.CancelFunc) ([]byte, int, string, bool, error) { client, err := p.httpClientForAttempt(attempt) if err != nil { return nil, 0, "", false, err @@ -1006,7 +1006,39 @@ func (p *HTTPProvider) doStreamAttempt(req *http.Request, attempt authAttempt, o scanner.Buffer(make([]byte, 0, 64*1024), 2*1024*1024) var dataLines []string var finalJSON []byte + lastActivity := time.Now() + firstDeltaSeen := false + done := make(chan struct{}) + if options != nil && cancel != nil { + go func() { + ticker := time.NewTicker(250 * time.Millisecond) + defer 
ticker.Stop() + defer close(done) + for { + select { + case <-req.Context().Done(): + return + case <-ticker.C: + timeout := options.firstDeltaTimeout + if firstDeltaSeen { + timeout = options.idleDeltaTimeout + } + if timeout > 0 && time.Since(lastActivity) > timeout { + options.staleTriggered = true + cancel() + return + } + } + } + }() + } else { + close(done) + } + defer func() { + <-done + }() for scanner.Scan() { + lastActivity = time.Now() line := scanner.Text() if strings.TrimSpace(line) == "" { if len(dataLines) > 0 { @@ -1015,6 +1047,7 @@ func (p *HTTPProvider) doStreamAttempt(req *http.Request, attempt authAttempt, o if strings.TrimSpace(payload) == "[DONE]" { continue } + firstDeltaSeen = true if onEvent != nil { onEvent(payload) } @@ -1039,6 +1072,9 @@ func (p *HTTPProvider) doStreamAttempt(req *http.Request, attempt authAttempt, o } } if err := scanner.Err(); err != nil { + if options != nil && options.staleTriggered { + return nil, resp.StatusCode, ctype, false, fmt.Errorf("stream stale: %w", err) + } return nil, resp.StatusCode, ctype, false, fmt.Errorf("failed to read stream: %w", err) } if len(finalJSON) == 0 { diff --git a/pkg/providers/iflow_provider.go b/pkg/providers/iflow_provider.go index fd9de26..f19070e 100644 --- a/pkg/providers/iflow_provider.go +++ b/pkg/providers/iflow_provider.go @@ -254,7 +254,7 @@ func doIFlowStreamWithAttempts(ctx context.Context, base *HTTPProvider, payload onDelta(txt) } } - }) + }, nil, nil) if err != nil { return nil, 0, "", err } diff --git a/pkg/providers/oauth_test.go b/pkg/providers/oauth_test.go index 2856674..9ffc332 100644 --- a/pkg/providers/oauth_test.go +++ b/pkg/providers/oauth_test.go @@ -216,7 +216,7 @@ func TestHTTPProviderOpenAICompatStreamMergesLateUsage(t *testing.T) { if err != nil { t.Fatalf("new request failed: %v", err) } - body, status, _, _, err := provider.doStreamAttempt(req, authAttempt{kind: "api_key", token: "token"}, nil) + body, status, _, _, err := provider.doStreamAttempt(req, 
authAttempt{kind: "api_key", token: "token"}, nil, nil, nil) if err != nil { t.Fatalf("stream attempt failed: %v", err) } diff --git a/pkg/providers/openai_compat_provider.go b/pkg/providers/openai_compat_provider.go index d409c4d..f4a2e71 100644 --- a/pkg/providers/openai_compat_provider.go +++ b/pkg/providers/openai_compat_provider.go @@ -143,7 +143,7 @@ func doOpenAICompatStreamWithAttempts(ctx context.Context, base *HTTPProvider, p onDelta(txt) } } - }) + }, nil, nil) if err != nil { return nil, 0, "", err } diff --git a/pkg/providers/types.go b/pkg/providers/types.go index d07a7f7..3846538 100644 --- a/pkg/providers/types.go +++ b/pkg/providers/types.go @@ -101,6 +101,19 @@ type ProviderExecutionError struct { Source string `json:"source,omitempty"` } +func (e *ProviderExecutionError) Error() string { + if e == nil { + return "" + } + if e.Message != "" { + return e.Message + } + if e.Code != "" { + return e.Code + } + return "provider execution error" +} + type ProviderExecutionResult struct { Body []byte `json:"-"` StatusCode int `json:"status_code,omitempty"` diff --git a/pkg/session/manager.go b/pkg/session/manager.go index 430b0a6..500827b 100644 --- a/pkg/session/manager.go +++ b/pkg/session/manager.go @@ -1,21 +1,34 @@ package session import ( - "bufio" "crypto/sha1" "encoding/hex" "encoding/json" "fmt" "os" "path/filepath" + "regexp" + "sort" "strconv" "strings" "sync" "time" + "unicode" + "github.com/YspCoder/clawgo/pkg/jsonlog" "github.com/YspCoder/clawgo/pkg/providers" ) +const ( + defaultSessionSegmentMaxMessages = 200 + defaultSessionSegmentMaxBytes = 2 * 1024 * 1024 + defaultSessionPromptLoadSegments = 2 + maxTokenRefsPerSessionToken = 48 + maxSearchSnippetsPerSession = 3 +) + +var archiveSegmentRe = regexp.MustCompile(`^(?P.+)\.(?P\d{4})\.jsonl$`) + type Session struct { Key string `json:"key"` SessionID string `json:"session_id,omitempty"` @@ -28,12 +41,18 @@ type Session struct { Created time.Time `json:"created"` Updated time.Time 
`json:"updated"` mu sync.RWMutex + segments []sessionSegmentMeta + nextSeq int + index *sessionIndexFile } type SessionManager struct { - sessions map[string]*Session - mu sync.RWMutex - storage string + sessions map[string]*Session + mu sync.RWMutex + storage string + segmentMaxMessages int + segmentMaxBytes int64 + promptLoadSegments int } type openClawEvent struct { @@ -51,14 +70,100 @@ type openClawEvent struct { } `json:"message,omitempty"` } +type sessionSegmentMeta struct { + Name string `json:"name"` + Archived bool `json:"archived,omitempty"` + FirstSeq int `json:"first_seq,omitempty"` + LastSeq int `json:"last_seq,omitempty"` + MessageCount int `json:"message_count,omitempty"` + LastOffset int64 `json:"last_offset,omitempty"` + UpdatedAt int64 `json:"updated_at,omitempty"` +} + +type sessionMetaFile struct { + Version int `json:"version"` + SessionKey string `json:"session_key"` + SessionID string `json:"session_id,omitempty"` + Kind string `json:"kind,omitempty"` + Summary string `json:"summary,omitempty"` + CompactionCount int `json:"compaction_count,omitempty"` + LastLanguage string `json:"last_language,omitempty"` + PreferredLanguage string `json:"preferred_language,omitempty"` + MessageCount int `json:"message_count,omitempty"` + CreatedAt int64 `json:"created_at,omitempty"` + UpdatedAt int64 `json:"updated_at,omitempty"` + NextSeq int `json:"next_seq,omitempty"` + Segments []sessionSegmentMeta `json:"segments,omitempty"` +} + +type sessionIndexRef struct { + Seq int `json:"seq"` + Role string `json:"role,omitempty"` + Segment string `json:"segment,omitempty"` + Snippet string `json:"snippet,omitempty"` +} + +type sessionIndexFile struct { + Version int `json:"version"` + SessionKey string `json:"session_key"` + LastSeq int `json:"last_seq,omitempty"` + LastOffset int64 `json:"last_offset,omitempty"` + Segment string `json:"segment,omitempty"` + UpdatedAt int64 `json:"updated_at,omitempty"` + Tokens map[string][]sessionIndexRef 
`json:"tokens,omitempty"` +} + +type SessionSearchSnippet struct { + Seq int `json:"seq"` + Role string `json:"role,omitempty"` + Segment string `json:"segment,omitempty"` + Content string `json:"content,omitempty"` +} + +type SessionSearchResult struct { + Key string `json:"key"` + Kind string `json:"kind,omitempty"` + Summary string `json:"summary,omitempty"` + UpdatedAt time.Time `json:"updated_at"` + Score int `json:"score"` + Snippets []SessionSearchSnippet `json:"snippets,omitempty"` +} + +type SessionCompactionSnapshot struct { + Key string + History []providers.Message + Summary string + NextSeq int +} + +type sessionsIndexEntry struct { + SessionID string `json:"sessionId"` + SessionKey string `json:"sessionKey"` + UpdatedAt int64 `json:"updatedAt"` + Kind string `json:"kind"` + ChatType string `json:"chatType"` + CompactionCount int `json:"compactionCount"` + SessionFile string `json:"sessionFile,omitempty"` + Summary string `json:"summary,omitempty"` + LastLanguage string `json:"lastLanguage,omitempty"` + PreferredLanguage string `json:"preferredLanguage,omitempty"` +} + +type appendMessageResult struct { + refreshSessionsIndex bool +} + func NewSessionManager(storage string) *SessionManager { sm := &SessionManager{ - sessions: make(map[string]*Session), - storage: storage, + sessions: make(map[string]*Session), + storage: storage, + segmentMaxMessages: readPositiveIntEnv("CLAWGO_SESSION_SEGMENT_MAX_MESSAGES", defaultSessionSegmentMaxMessages), + segmentMaxBytes: int64(readPositiveIntEnv("CLAWGO_SESSION_SEGMENT_MAX_BYTES", defaultSessionSegmentMaxBytes)), + promptLoadSegments: readPositiveIntEnv("CLAWGO_SESSION_PROMPT_LOAD_SEGMENTS", defaultSessionPromptLoadSegments), } if storage != "" { - os.MkdirAll(storage, 0755) + _ = os.MkdirAll(storage, 0755) sm.cleanupArchivedSessions() sm.loadSessions() } @@ -67,133 +172,193 @@ func NewSessionManager(storage string) *SessionManager { } func (sm *SessionManager) GetOrCreate(key string) *Session { + session, _ := 
sm.getOrCreate(key) + return session +} + +func (sm *SessionManager) getOrCreate(key string) (*Session, bool) { sm.mu.RLock() session, ok := sm.sessions[key] sm.mu.RUnlock() - if ok { - return session + return session, false } sm.mu.Lock() defer sm.mu.Unlock() - - // Re-check existence after acquiring Write lock if session, ok = sm.sessions[key]; ok { - return session + return session, false } - + now := time.Now() session = &Session{ Key: key, SessionID: deriveSessionID(key), Kind: detectSessionKind(key), Messages: []providers.Message{}, - Created: time.Now(), - Updated: time.Now(), + Created: now, + Updated: now, + nextSeq: 1, } sm.sessions[key] = session - - return session + return session, true } func (sm *SessionManager) AddMessage(sessionKey, role, content string) { - sm.AddMessageFull(sessionKey, providers.Message{ - Role: role, - Content: content, - }) + sm.AddMessageFull(sessionKey, providers.Message{Role: role, Content: content}) } func (sm *SessionManager) AddMessageFull(sessionKey string, msg providers.Message) { - session := sm.GetOrCreate(sessionKey) + session, created := sm.getOrCreate(sessionKey) + persisted := false + refreshIndex := created session.mu.Lock() session.Messages = append(session.Messages, msg) session.Updated = time.Now() + if session.Created.IsZero() { + session.Created = session.Updated + } + appendResult, err := sm.appendMessageLocked(session, msg) + persisted = err == nil + refreshIndex = refreshIndex || appendResult.refreshSessionsIndex session.mu.Unlock() - - // Persist immediately (append-only). 
- sm.appendMessage(sessionKey, msg) + if persisted && refreshIndex { + _ = sm.writeOpenClawSessionsIndex() + } } -func (sm *SessionManager) appendMessage(sessionKey string, msg providers.Message) error { - if sm.storage == "" { +func (sm *SessionManager) GetPromptHistory(key string) []providers.Message { + sm.mu.RLock() + session, ok := sm.sessions[key] + sm.mu.RUnlock() + if !ok { return nil } - - sessionPath := filepath.Join(sm.storage, sessionKey+".jsonl") - f, err := os.OpenFile(sessionPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) - if err != nil { - return err - } - defer f.Close() - - event := toOpenClawMessageEvent(msg) - data, err := json.Marshal(event) - if err != nil { - return err - } - - if _, err = f.Write(append(data, '\n')); err != nil { - return err - } - return sm.writeOpenClawSessionsIndex() + session.mu.RLock() + defer session.mu.RUnlock() + history := make([]providers.Message, len(session.Messages)) + copy(history, session.Messages) + return history } func (sm *SessionManager) GetHistory(key string) []providers.Message { sm.mu.RLock() session, ok := sm.sessions[key] sm.mu.RUnlock() - if !ok { return []providers.Message{} } session.mu.RLock() - defer session.mu.RUnlock() + segments := append([]sessionSegmentMeta(nil), session.segments...) 
+ loaded := make([]providers.Message, len(session.Messages)) + copy(loaded, session.Messages) + session.mu.RUnlock() - history := make([]providers.Message, len(session.Messages)) - copy(history, session.Messages) - return history + if len(segments) == 0 { + return loaded + } + all, err := sm.loadMessagesForSegments(segments) + if err != nil || len(all) == 0 { + return loaded + } + return all } func (sm *SessionManager) GetSummary(key string) string { sm.mu.RLock() session, ok := sm.sessions[key] sm.mu.RUnlock() - if !ok { return "" } - session.mu.RLock() defer session.mu.RUnlock() - return session.Summary } -func (sm *SessionManager) CompactSession(key string, keepLast int, note string) bool { - sm.mu.RLock() - session, ok := sm.sessions[key] - sm.mu.RUnlock() - if !ok { - return false - } - +func (sm *SessionManager) SetSummary(key, summary string) { + session := sm.GetOrCreate(key) session.mu.Lock() - defer session.mu.Unlock() - if keepLast <= 0 || len(session.Messages) <= keepLast { + trimmed := strings.TrimSpace(summary) + if session.Summary == trimmed { + session.mu.Unlock() + return + } + session.Summary = trimmed + session.Updated = time.Now() + _ = sm.persistSidecarsLocked(session) + session.mu.Unlock() + _ = sm.writeOpenClawSessionsIndex() +} + +func (sm *SessionManager) ApplyCompaction(key string, keep []providers.Message, summary string) bool { + session := sm.GetOrCreate(key) + session.mu.Lock() + applied := sm.applyCompactionLocked(session, keep, summary) + session.mu.Unlock() + if applied { + _ = sm.writeOpenClawSessionsIndex() + } + return applied +} + +func (sm *SessionManager) ApplyCompactionIfUnchanged(key string, baseNextSeq int, baseSummary string, keep []providers.Message, summary string) bool { + session := sm.GetOrCreate(key) + session.mu.Lock() + if session.nextSeq != baseNextSeq || session.Summary != baseSummary { + session.mu.Unlock() return false } - session.Messages = session.Messages[len(session.Messages)-keepLast:] + applied := 
sm.applyCompactionLocked(session, keep, summary) + session.mu.Unlock() + if applied { + _ = sm.writeOpenClawSessionsIndex() + } + return applied +} + +func (sm *SessionManager) applyCompactionLocked(session *Session, keep []providers.Message, summary string) bool { + if session == nil { + return false + } + if len(session.Messages) == 0 && len(keep) == 0 { + return false + } + session.Messages = append([]providers.Message(nil), keep...) + if trimmed := strings.TrimSpace(summary); trimmed != "" { + session.Summary = trimmed + } session.CompactionCount++ - if strings.TrimSpace(note) != "" { + session.Updated = time.Now() + if err := sm.persistSidecarsLocked(session); err != nil { + return false + } + return true +} + +func (sm *SessionManager) CompactSession(key string, keepLast int, note string) bool { + session := sm.GetOrCreate(key) + session.mu.Lock() + if keepLast <= 0 || len(session.Messages) <= keepLast { + session.mu.Unlock() + return false + } + session.Messages = append([]providers.Message(nil), session.Messages[len(session.Messages)-keepLast:]...) 
+ session.CompactionCount++ + if trimmed := strings.TrimSpace(note); trimmed != "" { if strings.TrimSpace(session.Summary) == "" { - session.Summary = note + session.Summary = trimmed } else { - session.Summary += "\n" + note + session.Summary += "\n" + trimmed } } session.Updated = time.Now() + err := sm.persistSidecarsLocked(session) + session.mu.Unlock() + if err == nil { + _ = sm.writeOpenClawSessionsIndex() + } return true } @@ -204,21 +369,45 @@ func (sm *SessionManager) GetLanguagePreferences(key string) (preferred string, if !ok { return "", "" } - session.mu.RLock() defer session.mu.RUnlock() return session.PreferredLanguage, session.LastLanguage } +func (sm *SessionManager) CompactionSnapshot(key string) SessionCompactionSnapshot { + sm.mu.RLock() + session, ok := sm.sessions[key] + sm.mu.RUnlock() + if !ok || session == nil { + return SessionCompactionSnapshot{Key: key} + } + session.mu.RLock() + defer session.mu.RUnlock() + history := make([]providers.Message, len(session.Messages)) + copy(history, session.Messages) + return SessionCompactionSnapshot{ + Key: key, + History: history, + Summary: session.Summary, + NextSeq: session.nextSeq, + } +} + func (sm *SessionManager) SetLastLanguage(key, lang string) { if strings.TrimSpace(lang) == "" { return } session := sm.GetOrCreate(key) session.mu.Lock() + if session.LastLanguage == lang { + session.mu.Unlock() + return + } session.LastLanguage = lang session.Updated = time.Now() + _ = sm.persistSidecarsLocked(session) session.mu.Unlock() + _ = sm.writeOpenClawSessionsIndex() } func (sm *SessionManager) SetPreferredLanguage(key, lang string) { @@ -227,17 +416,27 @@ func (sm *SessionManager) SetPreferredLanguage(key, lang string) { } session := sm.GetOrCreate(key) session.mu.Lock() + if session.PreferredLanguage == lang { + session.mu.Unlock() + return + } session.PreferredLanguage = lang session.Updated = time.Now() + _ = sm.persistSidecarsLocked(session) session.mu.Unlock() + _ = 
sm.writeOpenClawSessionsIndex() } func (sm *SessionManager) Save(session *Session) error { - // Messages are persisted incrementally via AddMessageFull. - // Metadata is now centralized in sessions.json (OpenClaw-style index). - if sm.storage == "" { + if sm.storage == "" || session == nil { return nil } + session.mu.Lock() + if err := sm.persistSidecarsLocked(session); err != nil { + session.mu.Unlock() + return err + } + session.mu.Unlock() return sm.writeOpenClawSessionsIndex() } @@ -254,6 +453,7 @@ func (sm *SessionManager) Keys() []string { for k := range sm.sessions { keys = append(keys, k) } + sort.Strings(keys) return keys } @@ -276,12 +476,92 @@ func (sm *SessionManager) List(limit int) []Session { }) s.mu.RUnlock() } + sort.Slice(items, func(i, j int) bool { return items[i].Updated.After(items[j].Updated) }) if limit > 0 && len(items) > limit { - return items[:limit] + items = items[:limit] } return items } +func (sm *SessionManager) Search(query string, kinds []string, excludeKey string, limit int) []SessionSearchResult { + terms := tokenizeQueryText(query) + if len(terms) == 0 { + return nil + } + if limit <= 0 { + limit = 5 + } + kindSet := make(map[string]struct{}, len(kinds)) + for _, item := range kinds { + if v := strings.ToLower(strings.TrimSpace(item)); v != "" { + kindSet[v] = struct{}{} + } + } + + sm.mu.RLock() + keys := make([]string, 0, len(sm.sessions)) + for key := range sm.sessions { + keys = append(keys, key) + } + sm.mu.RUnlock() + sort.Strings(keys) + + results := make([]SessionSearchResult, 0, len(keys)) + for _, key := range keys { + if key == strings.TrimSpace(excludeKey) { + continue + } + sm.mu.RLock() + session := sm.sessions[key] + sm.mu.RUnlock() + if session == nil { + continue + } + session.mu.RLock() + kind := session.Kind + summary := session.Summary + updated := session.Updated + index := cloneSessionIndex(session.index) + session.mu.RUnlock() + if len(kindSet) > 0 { + if _, ok := 
kindSet[strings.ToLower(strings.TrimSpace(kind))]; !ok { + continue + } + } + + if index != nil { + if result, ok := searchSessionIndex(key, kind, summary, updated, terms, index); ok { + results = append(results, result) + continue + } + } + indexPath := sm.sessionIndexPath(key) + if index, err := sm.readIndexFile(indexPath); err == nil { + if result, ok := searchSessionIndex(key, kind, summary, updated, terms, index); ok { + results = append(results, result) + continue + } + } + if result, ok := sm.searchSessionByScan(key, kind, summary, updated, terms); ok { + results = append(results, result) + } + } + + sort.Slice(results, func(i, j int) bool { + if results[i].Score != results[j].Score { + return results[i].Score > results[j].Score + } + if !results[i].UpdatedAt.Equal(results[j].UpdatedAt) { + return results[i].UpdatedAt.After(results[j].UpdatedAt) + } + return results[i].Key < results[j].Key + }) + if len(results) > limit { + results = results[:limit] + } + return results +} + func toOpenClawMessageEvent(msg providers.Message) openClawEvent { role := strings.TrimSpace(strings.ToLower(msg.Role)) mappedRole := role @@ -346,7 +626,6 @@ func fromJSONLLine(line []byte) (providers.Message, bool) { func deriveSessionID(key string) string { sum := sha1.Sum([]byte("clawgo-session:" + key)) h := hex.EncodeToString(sum[:]) - // UUID-like deterministic id return h[0:8] + "-" + h[8:12] + "-" + h[12:16] + "-" + h[16:20] + "-" + h[20:32] } @@ -372,24 +651,24 @@ func (sm *SessionManager) writeOpenClawSessionsIndex() error { } sm.mu.RLock() defer sm.mu.RUnlock() - index := map[string]map[string]interface{}{} + index := map[string]sessionsIndexEntry{} for key, s := range sm.sessions { s.mu.RLock() - sessionFile := filepath.Join(sm.storage, key+".jsonl") sid := strings.TrimSpace(s.SessionID) if sid == "" { sid = deriveSessionID(key) } - entry := map[string]interface{}{ - "sessionId": sid, - "sessionKey": key, - "updatedAt": s.Updated.UnixMilli(), - "systemSent": true, - 
"abortedLastRun": false, - "compactionCount": s.CompactionCount, - "chatType": mapKindToChatType(s.Kind), - "sessionFile": sessionFile, - "kind": s.Kind, + entry := sessionsIndexEntry{ + SessionID: sid, + SessionKey: key, + UpdatedAt: s.Updated.UnixMilli(), + ChatType: mapKindToChatType(s.Kind), + Kind: s.Kind, + CompactionCount: s.CompactionCount, + SessionFile: filepath.Join(sm.storage, activeSegmentFilename(key)), + Summary: s.Summary, + LastLanguage: s.LastLanguage, + PreferredLanguage: s.PreferredLanguage, } s.mu.RUnlock() index[key] = entry @@ -398,10 +677,7 @@ func (sm *SessionManager) writeOpenClawSessionsIndex() error { if err != nil { return err } - if err := os.WriteFile(filepath.Join(sm.storage, "sessions.json"), data, 0644); err != nil { - return err - } - return nil + return os.WriteFile(filepath.Join(sm.storage, "sessions.json"), data, 0644) } func (sm *SessionManager) cleanupArchivedSessions() { @@ -450,71 +726,1023 @@ func mapKindToChatType(kind string) string { } func (sm *SessionManager) loadSessions() error { - // 1) Load sessions index first (sessions.json as source of truth) - indexPath := filepath.Join(sm.storage, "sessions.json") - if data, err := os.ReadFile(indexPath); err == nil { - var index map[string]struct { - SessionID string `json:"sessionId"` - SessionKey string `json:"sessionKey"` - UpdatedAt int64 `json:"updatedAt"` - Kind string `json:"kind"` - ChatType string `json:"chatType"` - CompactionCount int `json:"compactionCount"` - } - if err := json.Unmarshal(data, &index); err == nil { - for key, row := range index { - session := sm.GetOrCreate(key) - session.mu.Lock() - if strings.TrimSpace(row.SessionID) != "" { - session.SessionID = row.SessionID - } - if strings.TrimSpace(row.Kind) != "" { - session.Kind = row.Kind - } else if strings.TrimSpace(row.ChatType) == "direct" { - session.Kind = "main" - } - if row.UpdatedAt > 0 { - session.Updated = time.UnixMilli(row.UpdatedAt) - } - session.CompactionCount = row.CompactionCount - 
session.mu.Unlock() - } - } - } - - // 2) Load JSONL histories - files, err := os.ReadDir(sm.storage) + fallbackIndex := sm.readSessionsIndex() + keys, err := sm.discoverSessionKeys() if err != nil { return err } - for _, file := range files { - if file.IsDir() || filepath.Ext(file.Name()) != ".jsonl" { + for _, key := range keys { + session, loadErr := sm.loadSessionFromDisk(key, fallbackIndex[key]) + if loadErr != nil { + return loadErr + } + if session == nil { continue } - sessionKey := strings.TrimSuffix(file.Name(), ".jsonl") - session := sm.GetOrCreate(sessionKey) - - f, err := os.Open(filepath.Join(sm.storage, file.Name())) - if err != nil { - continue - } - scanner := bufio.NewScanner(f) - scanner.Buffer(make([]byte, 0, 64*1024), 8*1024*1024) - session.mu.Lock() - for scanner.Scan() { - if msg, ok := fromJSONLLine(scanner.Bytes()); ok { - session.Messages = append(session.Messages, msg) - } - } - session.mu.Unlock() - closeErr := f.Close() - if err := scanner.Err(); err != nil { - return fmt.Errorf("scan session file %s: %w", file.Name(), err) - } - if closeErr != nil { - return fmt.Errorf("close session file %s: %w", file.Name(), closeErr) - } + sm.sessions[key] = session } - return sm.writeOpenClawSessionsIndex() } + +func (sm *SessionManager) readSessionsIndex() map[string]sessionsIndexEntry { + if sm.storage == "" { + return nil + } + data, err := os.ReadFile(filepath.Join(sm.storage, "sessions.json")) + if err != nil { + return nil + } + var index map[string]sessionsIndexEntry + if err := json.Unmarshal(data, &index); err != nil { + return nil + } + return index +} + +func (sm *SessionManager) loadSessionFromDisk(key string, fallback sessionsIndexEntry) (*Session, error) { + meta, err := sm.loadOrRebuildMeta(key) + if err != nil { + return nil, err + } + index, indexErr := sm.readIndexFile(sm.sessionIndexPath(key)) + if indexErr != nil { + index = nil + } + if meta == nil { + if strings.TrimSpace(fallback.SessionKey) == "" && key == "" { + return nil, 
nil + } + now := time.UnixMilli(fallback.UpdatedAt) + if fallback.UpdatedAt <= 0 { + now = time.Now() + } + return &Session{ + Key: key, + SessionID: firstNonEmpty(fallback.SessionID, deriveSessionID(key)), + Kind: firstNonEmpty(fallback.Kind, detectSessionKind(key)), + Summary: fallback.Summary, + CompactionCount: fallback.CompactionCount, + LastLanguage: fallback.LastLanguage, + PreferredLanguage: fallback.PreferredLanguage, + Created: now, + Updated: now, + nextSeq: 1, + index: index, + }, nil + } + + workingMessages, err := sm.loadWorkingSet(meta.Segments) + if err != nil { + return nil, err + } + created := time.UnixMilli(meta.CreatedAt) + if meta.CreatedAt <= 0 { + created = time.Now() + } + updated := time.UnixMilli(meta.UpdatedAt) + if meta.UpdatedAt <= 0 { + updated = created + } + if strings.TrimSpace(meta.Summary) == "" && strings.TrimSpace(fallback.Summary) != "" { + meta.Summary = fallback.Summary + } + if strings.TrimSpace(meta.LastLanguage) == "" && strings.TrimSpace(fallback.LastLanguage) != "" { + meta.LastLanguage = fallback.LastLanguage + } + if strings.TrimSpace(meta.PreferredLanguage) == "" && strings.TrimSpace(fallback.PreferredLanguage) != "" { + meta.PreferredLanguage = fallback.PreferredLanguage + } + return &Session{ + Key: key, + SessionID: firstNonEmpty(meta.SessionID, fallback.SessionID, deriveSessionID(key)), + Kind: firstNonEmpty(meta.Kind, fallback.Kind, detectSessionKind(key)), + Messages: workingMessages, + Summary: firstNonEmpty(meta.Summary, fallback.Summary), + CompactionCount: maxInt(meta.CompactionCount, fallback.CompactionCount), + LastLanguage: firstNonEmpty(meta.LastLanguage, fallback.LastLanguage), + PreferredLanguage: firstNonEmpty(meta.PreferredLanguage, fallback.PreferredLanguage), + Created: created, + Updated: updated, + segments: append([]sessionSegmentMeta(nil), meta.Segments...), + nextSeq: maxInt(meta.NextSeq, meta.MessageCount+1, 1), + index: index, + }, nil +} + +func (sm *SessionManager) loadOrRebuildMeta(key 
string) (*sessionMetaFile, error) { + metaPath := sm.sessionMetaPath(key) + indexPath := sm.sessionIndexPath(key) + + var meta sessionMetaFile + metaValid := false + if err := jsonlog.ReadJSON(metaPath, &meta); err == nil && meta.Version > 0 { + if sm.metaMatchesStorage(key, &meta) { + metaValid = true + } + } + if metaValid { + if _, err := sm.readIndexFile(indexPath); err == nil { + return &meta, nil + } + } + rebuilt, err := sm.rebuildSidecars(key, &meta) + if err != nil { + return nil, err + } + return rebuilt, nil +} + +func (sm *SessionManager) metaMatchesStorage(key string, meta *sessionMetaFile) bool { + if meta == nil || len(meta.Segments) == 0 { + return false + } + for _, segment := range meta.Segments { + size, err := jsonlog.FileSize(filepath.Join(sm.storage, segment.Name)) + if err != nil { + return false + } + if strings.EqualFold(segment.Name, activeSegmentFilename(key)) && size != segment.LastOffset { + return false + } + } + return true +} + +func (sm *SessionManager) rebuildSidecars(key string, seed *sessionMetaFile) (*sessionMetaFile, error) { + segments, err := sm.discoverSessionSegments(key) + if err != nil { + return nil, err + } + if len(segments) == 0 { + return nil, nil + } + + meta := &sessionMetaFile{ + Version: 1, + SessionKey: key, + SessionID: deriveSessionID(key), + Kind: detectSessionKind(key), + Summary: strings.TrimSpace(seedValue(seed, func(v *sessionMetaFile) string { return v.Summary })), + CompactionCount: seedInt(seed, func(v *sessionMetaFile) int { return v.CompactionCount }), + LastLanguage: seedValue(seed, func(v *sessionMetaFile) string { return v.LastLanguage }), + PreferredLanguage: seedValue(seed, func(v *sessionMetaFile) string { return v.PreferredLanguage }), + Segments: make([]sessionSegmentMeta, 0, len(segments)), + } + index := sessionIndexFile{ + Version: 1, + SessionKey: key, + LastSeq: 0, + LastOffset: 0, + Segment: "", + UpdatedAt: time.Now().UnixMilli(), + Tokens: map[string][]sessionIndexRef{}, + } + + now 
:= time.Now().UnixMilli() + seq := 0 + for _, name := range segments { + fullPath := filepath.Join(sm.storage, name) + size, err := jsonlog.FileSize(fullPath) + if err != nil { + return nil, err + } + seg := sessionSegmentMeta{ + Name: name, + Archived: !strings.EqualFold(name, activeSegmentFilename(key)), + LastOffset: size, + UpdatedAt: now, + } + if st, err := os.Stat(fullPath); err == nil { + seg.UpdatedAt = st.ModTime().UnixMilli() + if meta.CreatedAt == 0 || st.ModTime().UnixMilli() < meta.CreatedAt { + meta.CreatedAt = st.ModTime().UnixMilli() + } + if st.ModTime().UnixMilli() > meta.UpdatedAt { + meta.UpdatedAt = st.ModTime().UnixMilli() + } + } + if err := jsonlog.Scan(fullPath, func(line []byte) error { + msg, ok := fromJSONLLine(line) + if !ok { + return nil + } + seq++ + if seg.FirstSeq == 0 { + seg.FirstSeq = seq + } + seg.LastSeq = seq + seg.MessageCount++ + meta.MessageCount++ + appendTokens(index.Tokens, tokenizeIndexText(msg.Content), sessionIndexRef{ + Seq: seq, + Role: strings.ToLower(strings.TrimSpace(msg.Role)), + Segment: name, + Snippet: messageSnippet(msg.Content), + }) + return nil + }); err != nil { + return nil, err + } + meta.Segments = append(meta.Segments, seg) + } + if meta.CreatedAt == 0 { + meta.CreatedAt = now + } + if meta.UpdatedAt == 0 { + meta.UpdatedAt = now + } + meta.NextSeq = seq + 1 + if len(meta.Segments) > 0 { + last := meta.Segments[len(meta.Segments)-1] + index.LastSeq = last.LastSeq + index.LastOffset = last.LastOffset + index.Segment = last.Name + } + if err := sm.writeSidecarFiles(key, meta, &index); err != nil { + return nil, err + } + return meta, nil +} + +func (sm *SessionManager) loadWorkingSet(segments []sessionSegmentMeta) ([]providers.Message, error) { + if len(segments) == 0 { + return nil, nil + } + start := 0 + if sm.promptLoadSegments > 0 && len(segments) > sm.promptLoadSegments { + start = len(segments) - sm.promptLoadSegments + } + return sm.loadMessagesForSegments(segments[start:]) +} + +func (sm 
*SessionManager) loadMessagesForSegments(segments []sessionSegmentMeta) ([]providers.Message, error) { + out := make([]providers.Message, 0) + for _, segment := range segments { + path := filepath.Join(sm.storage, segment.Name) + if err := jsonlog.Scan(path, func(line []byte) error { + msg, ok := fromJSONLLine(line) + if ok { + out = append(out, msg) + } + return nil + }); err != nil { + if os.IsNotExist(err) { + continue + } + return nil, err + } + } + return out, nil +} + +func (sm *SessionManager) GetHistoryWindow(key string, around, before, after, limit int) []providers.Message { + sm.mu.RLock() + session, ok := sm.sessions[key] + sm.mu.RUnlock() + if !ok { + return nil + } + session.mu.RLock() + segments := append([]sessionSegmentMeta(nil), session.segments...) + session.mu.RUnlock() + if len(segments) == 0 { + return nil + } + startSeq, endSeq := computeHistorySeqWindow(segments, around, before, after, limit) + if endSeq < startSeq { + return nil + } + selected := make([]sessionSegmentMeta, 0, len(segments)) + for _, segment := range segments { + if segment.LastSeq < startSeq || segment.FirstSeq > endSeq { + continue + } + selected = append(selected, segment) + } + if len(selected) == 0 { + return nil + } + all, err := sm.loadMessagesForSegments(selected) + if err != nil { + return nil + } + out := make([]providers.Message, 0, len(all)) + seq := 0 + for _, segment := range selected { + for i := 0; i < segment.MessageCount && seq < len(all); i++ { + currentSeq := segment.FirstSeq + i + msg := all[seq] + seq++ + if currentSeq < startSeq || currentSeq > endSeq { + continue + } + out = append(out, msg) + } + } + return out +} + +func (sm *SessionManager) appendMessageLocked(session *Session, msg providers.Message) (appendMessageResult, error) { + if sm.storage == "" { + return appendMessageResult{}, nil + } + if err := sm.ensureActiveSegmentLocked(session); err != nil { + return appendMessageResult{}, err + } + result := appendMessageResult{} + if 
sm.shouldRolloverLocked(session) { + if err := sm.rolloverLocked(session); err != nil { + return appendMessageResult{}, err + } + result.refreshSessionsIndex = true + if err := sm.ensureActiveSegmentLocked(session); err != nil { + return appendMessageResult{}, err + } + } + active := sm.activeSegmentLocked(session) + if active == nil { + return appendMessageResult{}, fmt.Errorf("active session segment unavailable") + } + offset, err := jsonlog.AppendLine(filepath.Join(sm.storage, active.Name), toOpenClawMessageEvent(msg)) + if err != nil { + return appendMessageResult{}, err + } + if active.FirstSeq == 0 { + active.FirstSeq = session.nextSeq + } + active.LastSeq = session.nextSeq + active.MessageCount++ + active.LastOffset = offset + active.UpdatedAt = time.Now().UnixMilli() + sm.appendIndexLocked(session, msg, active.Name, active.LastOffset) + session.nextSeq++ + return result, sm.persistSidecarsLocked(session) +} + +func (sm *SessionManager) ensureActiveSegmentLocked(session *Session) error { + if session == nil { + return nil + } + for i := range session.segments { + if strings.EqualFold(session.segments[i].Name, activeSegmentFilename(session.Key)) { + return nil + } + } + session.segments = append(session.segments, sessionSegmentMeta{ + Name: activeSegmentFilename(session.Key), + Archived: false, + LastOffset: 0, + UpdatedAt: time.Now().UnixMilli(), + }) + return nil +} + +func (sm *SessionManager) activeSegmentLocked(session *Session) *sessionSegmentMeta { + if session == nil { + return nil + } + for i := range session.segments { + if strings.EqualFold(session.segments[i].Name, activeSegmentFilename(session.Key)) { + return &session.segments[i] + } + } + return nil +} + +func (sm *SessionManager) shouldRolloverLocked(session *Session) bool { + active := sm.activeSegmentLocked(session) + if active == nil { + return false + } + if sm.segmentMaxMessages > 0 && active.MessageCount >= sm.segmentMaxMessages { + return true + } + if sm.segmentMaxBytes > 0 && 
active.LastOffset >= sm.segmentMaxBytes { + return true + } + return false +} + +func (sm *SessionManager) rolloverLocked(session *Session) error { + active := sm.activeSegmentLocked(session) + if active == nil || active.MessageCount == 0 { + return nil + } + nextSeq := 1 + for _, segment := range session.segments { + if seq, ok := parseArchiveSegmentFilename(segment.Name, session.Key); ok && seq >= nextSeq { + nextSeq = seq + 1 + } + } + archiveName := archivedSegmentFilename(session.Key, nextSeq) + if err := os.Rename(filepath.Join(sm.storage, active.Name), filepath.Join(sm.storage, archiveName)); err != nil { + return err + } + active.Name = archiveName + active.Archived = true + if err := sm.ensureActiveSegmentLocked(session); err != nil { + return err + } + newActive := sm.activeSegmentLocked(session) + if newActive != nil { + newActive.Archived = false + newActive.FirstSeq = 0 + newActive.LastSeq = 0 + newActive.MessageCount = 0 + newActive.LastOffset = 0 + newActive.UpdatedAt = time.Now().UnixMilli() + } + rebuilt, err := sm.rebuildSidecars(session.Key, &sessionMetaFile{ + Summary: session.Summary, + CompactionCount: session.CompactionCount, + LastLanguage: session.LastLanguage, + PreferredLanguage: session.PreferredLanguage, + }) + if err != nil { + return err + } + if rebuilt != nil { + session.segments = append([]sessionSegmentMeta(nil), rebuilt.Segments...) 
+ session.nextSeq = maxInt(rebuilt.NextSeq, session.nextSeq) + } + return nil +} + +func (sm *SessionManager) persistSidecarsLocked(session *Session) error { + if sm.storage == "" || session == nil { + return nil + } + meta := sessionMetaFile{ + Version: 1, + SessionKey: session.Key, + SessionID: firstNonEmpty(session.SessionID, deriveSessionID(session.Key)), + Kind: firstNonEmpty(session.Kind, detectSessionKind(session.Key)), + Summary: strings.TrimSpace(session.Summary), + CompactionCount: session.CompactionCount, + LastLanguage: strings.TrimSpace(session.LastLanguage), + PreferredLanguage: strings.TrimSpace(session.PreferredLanguage), + MessageCount: sessionMessageCount(session), + CreatedAt: session.Created.UnixMilli(), + UpdatedAt: session.Updated.UnixMilli(), + NextSeq: maxInt(session.nextSeq, sessionMessageCount(session)+1), + Segments: append([]sessionSegmentMeta(nil), session.segments...), + } + if session.index == nil { + index, err := sm.buildIndexForSessionLocked(session, meta.NextSeq-1) + if err != nil { + return err + } + session.index = index + } + return sm.writeSidecarFiles(session.Key, &meta, session.index) +} + +func (sm *SessionManager) buildIndexForSessionLocked(session *Session, seqEnd int) (*sessionIndexFile, error) { + messages, err := sm.loadMessagesForSegments(session.segments) + if err != nil { + return nil, err + } + index := &sessionIndexFile{ + Version: 1, + SessionKey: session.Key, + LastSeq: 0, + LastOffset: 0, + Segment: "", + UpdatedAt: time.Now().UnixMilli(), + Tokens: map[string][]sessionIndexRef{}, + } + for i, msg := range messages { + ref := sessionIndexRef{ + Seq: i + 1, + Role: strings.ToLower(strings.TrimSpace(msg.Role)), + Segment: segmentNameForSeq(session.segments, i+1), + Snippet: messageSnippet(msg.Content), + } + appendTokens(index.Tokens, tokenizeIndexText(msg.Content), ref) + } + index.LastSeq = seqEnd + if active := sm.activeSegmentLocked(session); active != nil { + index.LastOffset = active.LastOffset + 
index.Segment = active.Name + } + return index, nil +} + +func segmentNameForSeq(segments []sessionSegmentMeta, seq int) string { + for _, segment := range segments { + if seq >= segment.FirstSeq && seq <= segment.LastSeq { + return segment.Name + } + } + return "" +} + +func sessionMessageCount(session *Session) int { + count := 0 + for _, segment := range session.segments { + count += segment.MessageCount + } + if count == 0 && len(session.Messages) > 0 { + count = len(session.Messages) + } + return count +} + +func (sm *SessionManager) writeSidecarFiles(key string, meta *sessionMetaFile, index *sessionIndexFile) error { + if err := jsonlog.WriteJSON(sm.sessionMetaPath(key), meta); err != nil { + return err + } + return jsonlog.WriteJSON(sm.sessionIndexPath(key), index) +} + +func (sm *SessionManager) readIndexFile(path string) (*sessionIndexFile, error) { + var index sessionIndexFile + if err := jsonlog.ReadJSON(path, &index); err != nil { + return nil, err + } + if index.Version <= 0 { + return nil, fmt.Errorf("invalid index version") + } + return &index, nil +} + +func (sm *SessionManager) searchSessionByScan(key, kind, summary string, updated time.Time, terms []string) (SessionSearchResult, bool) { + session := SessionSearchResult{ + Key: key, + Kind: kind, + Summary: summary, + UpdatedAt: updated, + } + all, err := sm.loadMessagesForSegments(sm.sessionSegments(key)) + if err != nil { + return SessionSearchResult{}, false + } + type scored struct { + score int + snippet SessionSearchSnippet + } + matches := make([]scored, 0) + for idx, msg := range all { + score := messageMatchScore(msg, terms) + if score == 0 { + continue + } + matches = append(matches, scored{ + score: score, + snippet: SessionSearchSnippet{ + Seq: idx + 1, + Role: strings.ToLower(strings.TrimSpace(msg.Role)), + Content: messageSnippet(msg.Content), + }, + }) + } + if len(matches) == 0 { + return SessionSearchResult{}, false + } + sort.Slice(matches, func(i, j int) bool { + if 
matches[i].score != matches[j].score { + return matches[i].score > matches[j].score + } + return matches[i].snippet.Seq > matches[j].snippet.Seq + }) + for i, item := range matches { + if i >= maxSearchSnippetsPerSession { + break + } + session.Score += item.score + session.Snippets = append(session.Snippets, item.snippet) + } + return session, true +} + +func (sm *SessionManager) sessionSegments(key string) []sessionSegmentMeta { + sm.mu.RLock() + session := sm.sessions[key] + sm.mu.RUnlock() + if session == nil { + return nil + } + session.mu.RLock() + defer session.mu.RUnlock() + return append([]sessionSegmentMeta(nil), session.segments...) +} + +func searchSessionIndex(key, kind, summary string, updated time.Time, terms []string, index *sessionIndexFile) (SessionSearchResult, bool) { + type aggregate struct { + score int + ref sessionIndexRef + } + hits := map[int]*aggregate{} + for _, term := range terms { + for _, ref := range index.Tokens[term] { + item := hits[ref.Seq] + if item == nil { + item = &aggregate{ref: ref} + hits[ref.Seq] = item + } + item.score++ + } + } + if len(hits) == 0 { + return SessionSearchResult{}, false + } + aggs := make([]aggregate, 0, len(hits)) + for _, item := range hits { + aggs = append(aggs, *item) + } + sort.Slice(aggs, func(i, j int) bool { + if aggs[i].score != aggs[j].score { + return aggs[i].score > aggs[j].score + } + return aggs[i].ref.Seq > aggs[j].ref.Seq + }) + result := SessionSearchResult{ + Key: key, + Kind: kind, + Summary: summary, + UpdatedAt: updated, + } + for i, agg := range aggs { + if i >= maxSearchSnippetsPerSession { + break + } + result.Score += agg.score + result.Snippets = append(result.Snippets, SessionSearchSnippet{ + Seq: agg.ref.Seq, + Role: agg.ref.Role, + Segment: agg.ref.Segment, + Content: agg.ref.Snippet, + }) + } + return result, true +} + +func appendTokens(dst map[string][]sessionIndexRef, tokens []string, ref sessionIndexRef) { + for _, token := range tokens { + items := dst[token] + if 
len(items) >= maxTokenRefsPerSessionToken { + continue + } + items = append(items, ref) + dst[token] = items + } +} + +func (sm *SessionManager) appendIndexLocked(session *Session, msg providers.Message, segment string, offset int64) { + if session == nil { + return + } + if session.index == nil { + session.index = &sessionIndexFile{ + Version: 1, + SessionKey: session.Key, + Tokens: map[string][]sessionIndexRef{}, + } + } + if session.index.Tokens == nil { + session.index.Tokens = map[string][]sessionIndexRef{} + } + ref := sessionIndexRef{ + Seq: session.nextSeq, + Role: strings.ToLower(strings.TrimSpace(msg.Role)), + Segment: segment, + Snippet: messageSnippet(msg.Content), + } + appendTokens(session.index.Tokens, tokenizeIndexText(msg.Content), ref) + session.index.LastSeq = session.nextSeq + session.index.LastOffset = offset + session.index.Segment = segment + session.index.UpdatedAt = time.Now().UnixMilli() +} + +func messageMatchScore(msg providers.Message, terms []string) int { + if len(terms) == 0 { + return 0 + } + content := strings.ToLower(msg.Content) + score := 0 + for _, term := range terms { + if strings.Contains(content, term) { + score++ + } + } + return score +} + +func messageSnippet(content string) string { + content = strings.TrimSpace(strings.ReplaceAll(content, "\n", " ")) + if len(content) <= 180 { + return content + } + return content[:180] + "..." 
+} + +func tokenizeIndexText(text string) []string { + return tokenizeSearchText(text, false) +} + +func tokenizeQueryText(text string) []string { + return tokenizeSearchText(text, true) +} + +func tokenizeSearchText(text string, includeSingleHan bool) []string { + text = strings.ToLower(strings.TrimSpace(text)) + if text == "" { + return nil + } + out := make([]string, 0, 16) + seen := map[string]struct{}{} + var asciiBuf strings.Builder + flushASCII := func() { + if asciiBuf.Len() == 0 { + return + } + fields := strings.FieldsFunc(asciiBuf.String(), func(r rune) bool { + return !unicode.IsLetter(r) && !unicode.IsNumber(r) + }) + for _, field := range fields { + field = strings.TrimSpace(field) + if len(field) < 2 { + continue + } + if _, ok := seen[field]; ok { + continue + } + seen[field] = struct{}{} + out = append(out, field) + } + asciiBuf.Reset() + } + var hanRunes []rune + flushHan := func() { + if len(hanRunes) == 0 { + return + } + if includeSingleHan && len(hanRunes) == 1 { + token := string(hanRunes[0]) + if _, ok := seen[token]; !ok { + seen[token] = struct{}{} + out = append(out, token) + } + } + if len(hanRunes) >= 2 { + for i := 0; i < len(hanRunes)-1; i++ { + token := string(hanRunes[i : i+2]) + if _, ok := seen[token]; ok { + continue + } + seen[token] = struct{}{} + out = append(out, token) + } + } + hanRunes = hanRunes[:0] + } + for _, r := range text { + switch { + case unicode.Is(unicode.Han, r): + flushASCII() + hanRunes = append(hanRunes, r) + case unicode.IsLetter(r) || unicode.IsNumber(r): + flushHan() + asciiBuf.WriteRune(r) + default: + flushASCII() + flushHan() + } + } + flushASCII() + flushHan() + return out +} + +func (sm *SessionManager) discoverSessionKeys() ([]string, error) { + if sm.storage == "" { + return nil, nil + } + entries, err := os.ReadDir(sm.storage) + if err != nil { + return nil, err + } + keys := map[string]struct{}{} + for _, entry := range entries { + if entry.IsDir() { + continue + } + if key, ok := 
sessionKeyFromFilename(entry.Name()); ok { + keys[key] = struct{}{} + } + } + out := make([]string, 0, len(keys)) + for key := range keys { + out = append(out, key) + } + sort.Strings(out) + return out, nil +} + +func (sm *SessionManager) discoverSessionSegments(key string) ([]string, error) { + entries, err := os.ReadDir(sm.storage) + if err != nil { + return nil, err + } + names := make([]string, 0) + for _, entry := range entries { + if entry.IsDir() { + continue + } + name := entry.Name() + switch { + case name == activeSegmentFilename(key): + names = append(names, name) + case name == legacySegmentFilename(key): + names = append(names, name) + default: + if _, ok := parseArchiveSegmentFilename(name, key); ok { + names = append(names, name) + } + } + } + sort.Slice(names, func(i, j int) bool { + return compareSegmentNames(key, names[i], names[j]) + }) + return names, nil +} + +func compareSegmentNames(key, left, right string) bool { + if left == right { + return false + } + if left == legacySegmentFilename(key) { + return true + } + if right == legacySegmentFilename(key) { + return false + } + if left == activeSegmentFilename(key) { + return false + } + if right == activeSegmentFilename(key) { + return true + } + li, _ := parseArchiveSegmentFilename(left, key) + ri, _ := parseArchiveSegmentFilename(right, key) + return li < ri +} + +func sessionKeyFromFilename(name string) (string, bool) { + switch { + case strings.HasSuffix(name, ".meta.json"): + return strings.TrimSuffix(name, ".meta.json"), true + case strings.HasSuffix(name, ".index.json"): + return strings.TrimSuffix(name, ".index.json"), true + case strings.HasSuffix(name, ".active.jsonl"): + return strings.TrimSuffix(name, ".active.jsonl"), true + case strings.HasSuffix(name, ".jsonl") && !strings.Contains(name, ".deleted."): + if m := archiveSegmentRe.FindStringSubmatch(name); len(m) == 3 { + return m[1], true + } + return strings.TrimSuffix(name, ".jsonl"), true + default: + return "", false + } +} + 
+func parseArchiveSegmentFilename(name, key string) (int, bool) { + m := archiveSegmentRe.FindStringSubmatch(name) + if len(m) != 3 || m[1] != key { + return 0, false + } + n, err := strconv.Atoi(m[2]) + if err != nil { + return 0, false + } + return n, true +} + +func activeSegmentFilename(key string) string { return key + ".active.jsonl" } +func legacySegmentFilename(key string) string { return key + ".jsonl" } +func archivedSegmentFilename(key string, seq int) string { + return fmt.Sprintf("%s.%04d.jsonl", key, seq) +} + +func (sm *SessionManager) sessionMetaPath(key string) string { + return filepath.Join(sm.storage, key+".meta.json") +} +func (sm *SessionManager) sessionIndexPath(key string) string { + return filepath.Join(sm.storage, key+".index.json") +} + +func readPositiveIntEnv(name string, fallback int) int { + if value := strings.TrimSpace(os.Getenv(name)); value != "" { + if n, err := strconv.Atoi(value); err == nil && n > 0 { + return n + } + } + return fallback +} + +func firstNonEmpty(values ...string) string { + for _, value := range values { + if strings.TrimSpace(value) != "" { + return value + } + } + return "" +} + +func seedValue(seed *sessionMetaFile, get func(*sessionMetaFile) string) string { + if seed == nil { + return "" + } + return get(seed) +} + +func seedInt(seed *sessionMetaFile, get func(*sessionMetaFile) int) int { + if seed == nil { + return 0 + } + return get(seed) +} + +func maxInt(values ...int) int { + best := 0 + for _, value := range values { + if value > best { + best = value + } + } + return best +} + +func cloneSessionIndex(index *sessionIndexFile) *sessionIndexFile { + if index == nil { + return nil + } + out := &sessionIndexFile{ + Version: index.Version, + SessionKey: index.SessionKey, + LastSeq: index.LastSeq, + LastOffset: index.LastOffset, + Segment: index.Segment, + UpdatedAt: index.UpdatedAt, + Tokens: make(map[string][]sessionIndexRef, len(index.Tokens)), + } + for key, refs := range index.Tokens { + 
out.Tokens[key] = append([]sessionIndexRef(nil), refs...) + } + return out +} + +func computeHistorySeqWindow(segments []sessionSegmentMeta, around, before, after, limit int) (int, int) { + total := 0 + for _, segment := range segments { + if segment.LastSeq > total { + total = segment.LastSeq + } + } + if total <= 0 { + return 1, 0 + } + if limit <= 0 { + limit = 50 + } + start := 1 + end := total + if around > 0 { + half := limit / 2 + if half < 1 { + half = 1 + } + start = around - half + end = around + half + if start < 1 { + start = 1 + } + if end > total { + end = total + } + } else { + if after > 0 { + start = after + 1 + } + if before > 0 { + end = before - 1 + } + } + if start < 1 { + start = 1 + } + if end > total { + end = total + } + if end < start { + return 1, 0 + } + if end-start+1 > limit { + start = end - limit + 1 + if start < 1 { + start = 1 + } + } + return start, end +} diff --git a/pkg/session/manager_test.go b/pkg/session/manager_test.go index a28af3d..2fc2b88 100644 --- a/pkg/session/manager_test.go +++ b/pkg/session/manager_test.go @@ -5,6 +5,10 @@ import ( "path/filepath" "strings" "testing" + "time" + + "github.com/YspCoder/clawgo/pkg/jsonlog" + "github.com/YspCoder/clawgo/pkg/providers" ) func TestLoadSessionsReturnsScannerErrorForOversizedLine(t *testing.T) { @@ -38,3 +42,197 @@ func TestFromJSONLLineParsesOpenClawToolResult(t *testing.T) { } } +func TestSessionManagerWritesSidecarsAndSearches(t *testing.T) { + t.Parallel() + + storage := t.TempDir() + sm := NewSessionManager(storage) + key := "cli:default" + + sm.AddMessage(key, "user", "deploy project alpha") + sm.AddMessage(key, "assistant", "deployment failed with timeout after contacting api gateway") + + for _, name := range []string{ + key + ".active.jsonl", + key + ".meta.json", + key + ".index.json", + } { + if _, err := os.Stat(filepath.Join(storage, name)); err != nil { + t.Fatalf("expected artifact %s: %v", name, err) + } + } + + results := sm.Search("deploy timeout", nil, 
"", 5) + if len(results) != 1 { + t.Fatalf("expected one search result, got %#v", results) + } + if results[0].Key != key || len(results[0].Snippets) == 0 { + t.Fatalf("unexpected search result: %#v", results[0]) + } +} + +func TestSessionManagerRebuildsMissingSidecarsFromJSONL(t *testing.T) { + t.Parallel() + + storage := t.TempDir() + sm := NewSessionManager(storage) + key := "cli:summary" + + sm.AddMessage(key, "user", "remember previous deploy steps") + sm.SetSummary(key, "Key Facts\n- Previous deploy steps were discussed.") + + if err := os.Remove(filepath.Join(storage, key+".meta.json")); err != nil { + t.Fatalf("remove meta: %v", err) + } + if err := os.Remove(filepath.Join(storage, key+".index.json")); err != nil { + t.Fatalf("remove index: %v", err) + } + + reloaded := NewSessionManager(storage) + if got := reloaded.GetSummary(key); !strings.Contains(got, "Previous deploy steps") { + t.Fatalf("expected summary recovered from fallback index, got %q", got) + } + results := reloaded.Search("deploy", nil, "", 5) + if len(results) != 1 || results[0].Key != key { + t.Fatalf("expected rebuilt search result, got %#v", results) + } +} + +func TestSessionManagerRollsOverActiveSegment(t *testing.T) { + t.Setenv("CLAWGO_SESSION_SEGMENT_MAX_MESSAGES", "2") + storage := t.TempDir() + sm := NewSessionManager(storage) + key := "cli:rollover" + + sm.AddMessage(key, "user", "one") + sm.AddMessage(key, "assistant", "two") + sm.AddMessage(key, "user", "three") + + if _, err := os.Stat(filepath.Join(storage, key+".0001.jsonl")); err != nil { + t.Fatalf("expected archived segment: %v", err) + } + if _, err := os.Stat(filepath.Join(storage, key+".active.jsonl")); err != nil { + t.Fatalf("expected new active segment: %v", err) + } + history := sm.GetHistory(key) + if len(history) != 3 { + t.Fatalf("expected full history across segments, got %d", len(history)) + } +} + +func TestSessionManagerHistoryWindowAndIncrementalIndex(t *testing.T) { + 
t.Setenv("CLAWGO_SESSION_SEGMENT_MAX_MESSAGES", "2") + storage := t.TempDir() + sm := NewSessionManager(storage) + key := "cli:window" + + sm.AddMessage(key, "user", "one") + sm.AddMessage(key, "assistant", "two") + sm.AddMessage(key, "user", "three") + sm.AddMessage(key, "assistant", "four") + + window := sm.GetHistoryWindow(key, 0, 0, 2, 2) + if len(window) != 2 || window[0].Content != "three" || window[1].Content != "four" { + t.Fatalf("unexpected history window: %#v", window) + } + + var index sessionIndexFile + if err := jsonlog.ReadJSON(filepath.Join(storage, key+".index.json"), &index); err != nil { + t.Fatalf("read index: %v", err) + } + size, err := jsonlog.FileSize(filepath.Join(storage, key+".active.jsonl")) + if err != nil { + t.Fatalf("file size: %v", err) + } + if index.LastSeq != 4 { + t.Fatalf("expected last seq 4, got %d", index.LastSeq) + } + if index.LastOffset != size { + t.Fatalf("expected last offset %d, got %d", size, index.LastOffset) + } + if index.Segment != key+".active.jsonl" { + t.Fatalf("unexpected index segment %q", index.Segment) + } +} + +func TestSessionManagerSearchSupportsChineseBigrams(t *testing.T) { + t.Parallel() + + storage := t.TempDir() + sm := NewSessionManager(storage) + key := "cli:zh" + + sm.AddMessage(key, "user", "之前讨论过发布回滚方案") + sm.AddMessage(key, "assistant", "回滚需要先确认数据库版本") + + results := sm.Search("回滚方案", nil, "", 5) + if len(results) != 1 || results[0].Key != key { + t.Fatalf("expected chinese query to hit sidecar index, got %#v", results) + } + + if err := os.Remove(filepath.Join(storage, key+".index.json")); err != nil { + t.Fatalf("remove index: %v", err) + } + reloaded := NewSessionManager(storage) + results = reloaded.Search("回滚方案", nil, "", 5) + if len(results) != 1 || results[0].Key != key { + t.Fatalf("expected chinese query to hit scan fallback, got %#v", results) + } +} + +func TestApplyCompactionIfUnchangedRejectsChangedSession(t *testing.T) { + t.Parallel() + + sm := NewSessionManager(t.TempDir()) + 
key := "cli:guard" + sm.AddMessage(key, "user", "one") + sm.AddMessage(key, "assistant", "two") + snapshot := sm.CompactionSnapshot(key) + sm.AddMessage(key, "user", "three") + + applied := sm.ApplyCompactionIfUnchanged(key, snapshot.NextSeq, snapshot.Summary, []providers.Message{{Role: "assistant", Content: "two"}}, "Key Facts\n- compacted") + if applied { + t.Fatal("expected stale compaction application to be rejected") + } + history := sm.GetPromptHistory(key) + if len(history) != 3 || history[2].Content != "three" { + t.Fatalf("expected newer message to remain, got %#v", history) + } +} + +func TestSessionManagerAppendDoesNotRewriteSessionsIndex(t *testing.T) { + t.Parallel() + + storage := t.TempDir() + sm := NewSessionManager(storage) + key := "cli:index" + + sm.AddMessage(key, "user", "first") + indexPath := filepath.Join(storage, "sessions.json") + before, err := os.ReadFile(indexPath) + if err != nil { + t.Fatalf("read sessions index: %v", err) + } + statBefore, err := os.Stat(indexPath) + if err != nil { + t.Fatalf("stat sessions index: %v", err) + } + + time.Sleep(10 * time.Millisecond) + sm.AddMessage(key, "assistant", "second") + + after, err := os.ReadFile(indexPath) + if err != nil { + t.Fatalf("read sessions index after append: %v", err) + } + statAfter, err := os.Stat(indexPath) + if err != nil { + t.Fatalf("stat sessions index after append: %v", err) + } + if string(before) != string(after) { + t.Fatalf("expected sessions.json to stay unchanged for hot-path append") + } + if !statAfter.ModTime().Equal(statBefore.ModTime()) { + t.Fatalf("expected sessions.json mtime to stay unchanged, before=%s after=%s", statBefore.ModTime(), statAfter.ModTime()) + } +} diff --git a/pkg/tools/runtime_types.go b/pkg/tools/runtime_types.go index feac799..a07fba8 100644 --- a/pkg/tools/runtime_types.go +++ b/pkg/tools/runtime_types.go @@ -66,16 +66,17 @@ type RunRecord struct { } type EventRecord struct { - ID string `json:"id,omitempty"` - RunID string 
`json:"run_id,omitempty"` - RequestID string `json:"request_id,omitempty"` - AgentID string `json:"agent_id,omitempty"` - Type string `json:"type"` - Status string `json:"status,omitempty"` - Message string `json:"message,omitempty"` - RetryCount int `json:"retry_count,omitempty"` - Error *RuntimeError `json:"error,omitempty"` - At int64 `json:"ts"` + ID string `json:"id,omitempty"` + RunID string `json:"run_id,omitempty"` + RequestID string `json:"request_id,omitempty"` + AgentID string `json:"agent_id,omitempty"` + Type string `json:"type"` + Status string `json:"status,omitempty"` + FailureCode string `json:"failure_code,omitempty"` + Message string `json:"message,omitempty"` + RetryCount int `json:"retry_count,omitempty"` + Error *RuntimeError `json:"error,omitempty"` + At int64 `json:"ts"` } type ArtifactRecord struct { diff --git a/pkg/tools/session_search.go b/pkg/tools/session_search.go new file mode 100644 index 0000000..cf9ace6 --- /dev/null +++ b/pkg/tools/session_search.go @@ -0,0 +1,103 @@ +package tools + +import ( + "context" + "fmt" + "strings" + + "github.com/YspCoder/clawgo/pkg/session" +) + +type SessionSearchTool struct { + manager *session.SessionManager +} + +func NewSessionSearchTool(manager *session.SessionManager) *SessionSearchTool { + return &SessionSearchTool{manager: manager} +} + +func (t *SessionSearchTool) Name() string { return "session_search" } + +func (t *SessionSearchTool) Description() string { + return "Search past session history across JSONL session logs. Use when the user refers to previous conversations, past decisions, or earlier project work." 
+} + +func (t *SessionSearchTool) Parameters() map[string]interface{} { + return map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "query": map[string]interface{}{ + "type": "string", + "description": "Keywords or short phrase to search across past sessions.", + }, + "limit": map[string]interface{}{ + "type": "integer", + "description": "Maximum number of matching sessions to return.", + "default": 5, + }, + "kinds": map[string]interface{}{ + "type": "array", + "description": "Optional session kinds filter, e.g. main, cron, subagent, hook.", + "items": map[string]interface{}{"type": "string"}, + }, + "exclude_current": map[string]interface{}{ + "type": "boolean", + "description": "Exclude the current session from results when session_key is provided.", + "default": true, + }, + "session_key": map[string]interface{}{ + "type": "string", + "description": "Optional current session key for exclude_current behavior.", + }, + }, + "required": []string{"query"}, + } +} + +func (t *SessionSearchTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) { + _ = ctx + if t == nil || t.manager == nil { + return "", fmt.Errorf("session manager not configured") + } + query := MapStringArg(args, "query") + if query == "" { + return "", fmt.Errorf("query is required") + } + limit := MapIntArg(args, "limit", 5) + kinds := MapStringListArg(args, "kinds") + excludeCurrent, excludeSet := MapBoolArg(args, "exclude_current") + excludeKey := "" + if excludeSet && excludeCurrent { + excludeKey = MapStringArg(args, "session_key") + } + + results := t.manager.Search(query, kinds, excludeKey, limit) + if len(results) == 0 { + return fmt.Sprintf("No past sessions matched %q.", query), nil + } + + var sb strings.Builder + sb.WriteString(fmt.Sprintf("Session search results for %q:\n\n", query)) + for _, item := range results { + sb.WriteString(fmt.Sprintf("- %s kind=%s updated=%s score=%d\n", item.Key, item.Kind, 
item.UpdatedAt.Format("2006-01-02 15:04:05"), item.Score)) + if summary := strings.TrimSpace(item.Summary); summary != "" { + sb.WriteString(" summary: " + singleLine(summary, 220) + "\n") + } + for _, snippet := range item.Snippets { + line := singleLine(snippet.Content, 220) + sb.WriteString(fmt.Sprintf(" [#%d][%s] %s\n", snippet.Seq, snippet.Role, line)) + } + sb.WriteString("\n") + } + return strings.TrimSpace(sb.String()), nil +} + +func (t *SessionSearchTool) ParallelSafe() bool { return true } + +func singleLine(s string, max int) string { + s = strings.TrimSpace(strings.ReplaceAll(s, "\n", " ")) + if max > 0 && len(s) > max { + return s[:max] + "..." + } + return s +} diff --git a/pkg/tools/spawn.go b/pkg/tools/spawn.go index d3fd356..366cbb9 100644 --- a/pkg/tools/spawn.go +++ b/pkg/tools/spawn.go @@ -61,6 +61,10 @@ func (t *SpawnTool) Parameters() map[string]interface{} { "type": "integer", "description": "Optional per-attempt timeout in seconds.", }, + "max_tool_iterations": map[string]interface{}{ + "type": "integer", + "description": "Optional independent tool-calling iteration budget.", + }, "max_task_chars": map[string]interface{}{ "type": "integer", "description": "Optional task size quota in characters.", @@ -101,6 +105,7 @@ func (t *SpawnTool) Execute(ctx context.Context, args map[string]interface{}) (s maxRetries := MapIntArg(args, "max_retries", 0) retryBackoff := MapIntArg(args, "retry_backoff_ms", 0) timeoutSec := MapIntArg(args, "timeout_sec", 0) + maxToolIterations := MapIntArg(args, "max_tool_iterations", 0) maxTaskChars := MapIntArg(args, "max_task_chars", 0) maxResultChars := MapIntArg(args, "max_result_chars", 0) if label == "" && role != "" { @@ -129,17 +134,18 @@ func (t *SpawnTool) Execute(ctx context.Context, args map[string]interface{}) (s } result, err := t.manager.Spawn(ctx, SubagentSpawnOptions{ - Task: task, - Label: label, - Role: role, - AgentID: agentID, - MaxRetries: maxRetries, - RetryBackoff: retryBackoff, - TimeoutSec: 
timeoutSec, - MaxTaskChars: maxTaskChars, - MaxResultChars: maxResultChars, - OriginChannel: originChannel, - OriginChatID: originChatID, + Task: task, + Label: label, + Role: role, + AgentID: agentID, + MaxRetries: maxRetries, + RetryBackoff: retryBackoff, + TimeoutSec: timeoutSec, + MaxToolIterations: maxToolIterations, + MaxTaskChars: maxTaskChars, + MaxResultChars: maxResultChars, + OriginChannel: originChannel, + OriginChatID: originChatID, }) if err != nil { return "", fmt.Errorf("failed to spawn subagent: %w", err) diff --git a/pkg/tools/subagent.go b/pkg/tools/subagent.go index 6854a48..717d97a 100644 --- a/pkg/tools/subagent.go +++ b/pkg/tools/subagent.go @@ -15,37 +15,42 @@ import ( ) type SubagentRun struct { - ID string `json:"id"` - Task string `json:"task"` - Label string `json:"label"` - Role string `json:"role"` - AgentID string `json:"agent_id"` - Transport string `json:"transport,omitempty"` - ParentAgentID string `json:"parent_agent_id,omitempty"` - NotifyMainPolicy string `json:"notify_main_policy,omitempty"` - SessionKey string `json:"session_key"` - MemoryNS string `json:"memory_ns"` - SystemPromptFile string `json:"system_prompt_file,omitempty"` - ToolAllowlist []string `json:"tool_allowlist,omitempty"` - MaxRetries int `json:"max_retries,omitempty"` - RetryBackoff int `json:"retry_backoff,omitempty"` - TimeoutSec int `json:"timeout_sec,omitempty"` - MaxTaskChars int `json:"max_task_chars,omitempty"` - MaxResultChars int `json:"max_result_chars,omitempty"` - RetryCount int `json:"retry_count,omitempty"` - ThreadID string `json:"thread_id,omitempty"` - CorrelationID string `json:"correlation_id,omitempty"` - ParentRunID string `json:"parent_run_id,omitempty"` - LastMessageID string `json:"last_message_id,omitempty"` - WaitingReply bool `json:"waiting_for_reply,omitempty"` - SharedState map[string]interface{} `json:"shared_state,omitempty"` - OriginChannel string `json:"origin_channel,omitempty"` - OriginChatID string 
`json:"origin_chat_id,omitempty"` - Status string `json:"status"` - Result string `json:"result,omitempty"` - Steering []string `json:"steering,omitempty"` - Created int64 `json:"created"` - Updated int64 `json:"updated"` + ID string `json:"id"` + Task string `json:"task"` + Label string `json:"label"` + Role string `json:"role"` + AgentID string `json:"agent_id"` + Transport string `json:"transport,omitempty"` + ParentAgentID string `json:"parent_agent_id,omitempty"` + NotifyMainPolicy string `json:"notify_main_policy,omitempty"` + SessionKey string `json:"session_key"` + MemoryNS string `json:"memory_ns"` + SystemPromptFile string `json:"system_prompt_file,omitempty"` + ToolAllowlist []string `json:"tool_allowlist,omitempty"` + MaxRetries int `json:"max_retries,omitempty"` + RetryBackoff int `json:"retry_backoff,omitempty"` + TimeoutSec int `json:"timeout_sec,omitempty"` + MaxToolIterations int `json:"max_tool_iterations,omitempty"` + MaxTaskChars int `json:"max_task_chars,omitempty"` + MaxResultChars int `json:"max_result_chars,omitempty"` + RetryCount int `json:"retry_count,omitempty"` + IterationCount int `json:"iteration_count,omitempty"` + AttemptCount int `json:"attempt_count,omitempty"` + RestartCount int `json:"restart_count,omitempty"` + LastFailureCode string `json:"last_failure_code,omitempty"` + ThreadID string `json:"thread_id,omitempty"` + CorrelationID string `json:"correlation_id,omitempty"` + ParentRunID string `json:"parent_run_id,omitempty"` + LastMessageID string `json:"last_message_id,omitempty"` + WaitingReply bool `json:"waiting_for_reply,omitempty"` + SharedState map[string]interface{} `json:"shared_state,omitempty"` + OriginChannel string `json:"origin_channel,omitempty"` + OriginChatID string `json:"origin_chat_id,omitempty"` + Status string `json:"status"` + Result string `json:"result,omitempty"` + Steering []string `json:"steering,omitempty"` + Created int64 `json:"created"` + Updated int64 `json:"updated"` } type SubagentManager 
struct { @@ -66,21 +71,22 @@ type SubagentManager struct { } type SubagentSpawnOptions struct { - Task string - Label string - Role string - AgentID string - NotifyMainPolicy string - MaxRetries int - RetryBackoff int - TimeoutSec int - MaxTaskChars int - MaxResultChars int - OriginChannel string - OriginChatID string - ThreadID string - CorrelationID string - ParentRunID string + Task string + Label string + Role string + AgentID string + NotifyMainPolicy string + MaxRetries int + RetryBackoff int + TimeoutSec int + MaxToolIterations int + MaxTaskChars int + MaxResultChars int + OriginChannel string + OriginChatID string + ThreadID string + CorrelationID string + ParentRunID string } func NewSubagentManager(provider providers.LLMProvider, workspace string, bus *bus.MessageBus) *SubagentManager { @@ -173,6 +179,7 @@ func (sm *SubagentManager) spawnRun(ctx context.Context, opts SubagentSpawnOptio maxRetries := 0 retryBackoff := 1000 timeoutSec := 0 + maxToolIterations := 0 maxTaskChars := 0 maxResultChars := 0 if profile == nil && sm.profileStore != nil { @@ -206,6 +213,7 @@ func (sm *SubagentManager) spawnRun(ctx context.Context, opts SubagentSpawnOptio maxRetries = profile.MaxRetries retryBackoff = profile.RetryBackoff timeoutSec = profile.TimeoutSec + maxToolIterations = profile.MaxToolIterations maxTaskChars = profile.MaxTaskChars maxResultChars = profile.MaxResultChars } @@ -218,6 +226,9 @@ func (sm *SubagentManager) spawnRun(ctx context.Context, opts SubagentSpawnOptio if opts.TimeoutSec > 0 { timeoutSec = opts.TimeoutSec } + if opts.MaxToolIterations > 0 { + maxToolIterations = opts.MaxToolIterations + } if opts.MaxTaskChars > 0 { maxTaskChars = opts.MaxTaskChars } @@ -230,6 +241,7 @@ func (sm *SubagentManager) spawnRun(ctx context.Context, opts SubagentSpawnOptio maxRetries = normalizePositiveBound(maxRetries, 0, 8) retryBackoff = normalizePositiveBound(retryBackoff, 500, 120000) timeoutSec = normalizePositiveBound(timeoutSec, 0, 3600) + maxToolIterations = 
normalizePositiveBound(maxToolIterations, 0, 200) maxTaskChars = normalizePositiveBound(maxTaskChars, 0, 400000) maxResultChars = normalizePositiveBound(maxResultChars, 0, 400000) if role == "" { @@ -270,32 +282,33 @@ func (sm *SubagentManager) spawnRun(ctx context.Context, opts SubagentSpawnOptio } } subagentRun := &SubagentRun{ - ID: runID, - Task: task, - Label: label, - Role: role, - AgentID: agentID, - Transport: transport, - ParentAgentID: parentAgentID, - NotifyMainPolicy: notifyMainPolicy, - SessionKey: sessionKey, - MemoryNS: memoryNS, - SystemPromptFile: systemPromptFile, - ToolAllowlist: toolAllowlist, - MaxRetries: maxRetries, - RetryBackoff: retryBackoff, - TimeoutSec: timeoutSec, - MaxTaskChars: maxTaskChars, - MaxResultChars: maxResultChars, - RetryCount: 0, - ThreadID: threadID, - CorrelationID: correlationID, - ParentRunID: parentRunID, - OriginChannel: originChannel, - OriginChatID: originChatID, - Status: RuntimeStatusRouting, - Created: now, - Updated: now, + ID: runID, + Task: task, + Label: label, + Role: role, + AgentID: agentID, + Transport: transport, + ParentAgentID: parentAgentID, + NotifyMainPolicy: notifyMainPolicy, + SessionKey: sessionKey, + MemoryNS: memoryNS, + SystemPromptFile: systemPromptFile, + ToolAllowlist: toolAllowlist, + MaxRetries: maxRetries, + RetryBackoff: retryBackoff, + TimeoutSec: timeoutSec, + MaxToolIterations: maxToolIterations, + MaxTaskChars: maxTaskChars, + MaxResultChars: maxResultChars, + RetryCount: 0, + ThreadID: threadID, + CorrelationID: correlationID, + ParentRunID: parentRunID, + OriginChannel: originChannel, + OriginChatID: originChatID, + Status: RuntimeStatusRouting, + Created: now, + Updated: now, } taskCtx, cancel := context.WithCancel(ctx) sm.runs[runID] = subagentRun @@ -335,6 +348,7 @@ func (sm *SubagentManager) runSubagent(ctx context.Context, run *SubagentRun) { sm.mu.Lock() if runErr != nil { run.Status = RuntimeStatusFailed + run.LastFailureCode = classifySubagentFailureCode(runErr) 
run.Result = fmt.Sprintf("Error: %v", runErr) run.Result = applySubagentResultQuota(run.Result, run.MaxResultChars) run.Updated = time.Now().UnixMilli() @@ -354,6 +368,7 @@ func (sm *SubagentManager) runSubagent(ctx context.Context, run *SubagentRun) { sm.notifyRunWaitersLocked(run.ID) } else { run.Status = RuntimeStatusCompleted + run.LastFailureCode = "" run.Result = applySubagentResultQuota(result, run.MaxResultChars) run.Updated = time.Now().UnixMilli() run.WaitingReply = false @@ -383,16 +398,17 @@ func (sm *SubagentManager) runSubagent(ctx context.Context, run *SubagentRun) { SessionKey: run.SessionKey, Content: announceContent, Metadata: map[string]string{ - "trigger": "subagent", - "subagent_id": run.ID, - "agent_id": run.AgentID, - "role": run.Role, - "session_key": run.SessionKey, - "memory_ns": run.MemoryNS, - "retry_count": fmt.Sprintf("%d", run.RetryCount), - "timeout_sec": fmt.Sprintf("%d", run.TimeoutSec), - "status": run.Status, - "notify_reason": notifyReason, + "trigger": "subagent", + "subagent_id": run.ID, + "agent_id": run.AgentID, + "role": run.Role, + "session_key": run.SessionKey, + "memory_ns": run.MemoryNS, + "retry_count": fmt.Sprintf("%d", run.RetryCount), + "iteration_count": fmt.Sprintf("%d", run.IterationCount), + "timeout_sec": fmt.Sprintf("%d", run.TimeoutSec), + "status": run.Status, + "notify_reason": notifyReason, }, }) } @@ -504,6 +520,10 @@ func (sm *SubagentManager) runWithRetry(ctx context.Context, run *SubagentRun) ( var lastErr error for attempt := 0; attempt <= maxRetries; attempt++ { + if remaining := sm.remainingIterations(run); remaining == 0 && run.MaxToolIterations > 0 { + run.LastFailureCode = "retry_limit" + return "", fmt.Errorf("subagent iteration budget exhausted") + } runCtx := ctx var cancel context.CancelFunc if timeoutSec > 0 { @@ -516,18 +536,20 @@ func (sm *SubagentManager) runWithRetry(ctx context.Context, run *SubagentRun) ( if err == nil { sm.mu.Lock() run.RetryCount = attempt + run.LastFailureCode = "" 
run.Updated = time.Now().UnixMilli() sm.persistRunLocked(run, "attempt_succeeded", "") sm.mu.Unlock() return result, nil } lastErr = err + run.LastFailureCode = classifySubagentFailureCode(err) sm.mu.Lock() run.RetryCount = attempt run.Updated = time.Now().UnixMilli() sm.persistRunLocked(run, "attempt_failed", err.Error()) sm.mu.Unlock() - if attempt >= maxRetries { + if attempt >= maxRetries || !shouldRetrySubagentError(err) { break } select { @@ -547,8 +569,14 @@ func (sm *SubagentManager) executeRunOnce(ctx context.Context, run *SubagentRun) return "", fmt.Errorf("subagent run is nil") } pending, consumedIDs := sm.consumeThreadInbox(run) + stats := &SubagentExecutionStats{} + ctx = WithSubagentExecutionStats(ctx, stats) if sm.runFunc != nil { + if remaining := sm.remainingIterations(run); remaining > 0 { + ctx = WithSubagentIterationBudget(ctx, remaining) + } result, err := sm.runFunc(ctx, run) + sm.applyExecutionStats(run, stats) if err != nil { sm.restoreMessageStatuses(consumedIDs) } else { @@ -582,6 +610,8 @@ func (sm *SubagentManager) executeRunOnce(ctx context.Context, run *SubagentRun) response, err := sm.provider.Chat(ctx, messages, nil, sm.provider.GetDefaultModel(), map[string]interface{}{ "max_tokens": 4096, }) + stats.Attempts++ + sm.applyExecutionStats(run, stats) if err != nil { sm.restoreMessageStatuses(consumedIDs) return "", err @@ -723,15 +753,16 @@ func (sm *SubagentManager) RuntimeSnapshot(limit int) RuntimeSnapshot { if evts, err := sm.Events(run.ID, limit); err == nil { for _, evt := range evts { snapshot.Events = append(snapshot.Events, EventRecord{ - ID: EventRecordID(evt.RunID, evt.Type, evt.At), - RunID: evt.RunID, - RequestID: evt.RunID, - AgentID: evt.AgentID, - Type: evt.Type, - Status: evt.Status, - Message: evt.Message, - RetryCount: evt.RetryCount, - At: evt.At, + ID: EventRecordID(evt.RunID, evt.Type, evt.At), + RunID: evt.RunID, + RequestID: evt.RunID, + AgentID: evt.AgentID, + Type: evt.Type, + Status: evt.Status, + 
FailureCode: evt.FailureCode, + Message: evt.Message, + RetryCount: evt.RetryCount, + At: evt.At, }) } } @@ -823,6 +854,71 @@ func applySubagentResultQuota(result string, maxChars int) string { return strings.TrimSpace(trimmed) + suffix } +func (sm *SubagentManager) remainingIterations(run *SubagentRun) int { + if run == nil || run.MaxToolIterations <= 0 { + return 0 + } + remaining := run.MaxToolIterations - run.IterationCount + if remaining < 0 { + return 0 + } + return remaining +} + +func (sm *SubagentManager) applyExecutionStats(run *SubagentRun, stats *SubagentExecutionStats) { + if run == nil || stats == nil { + return + } + run.IterationCount += stats.Iterations + run.AttemptCount += stats.Attempts + run.RestartCount += stats.Restarts + if strings.TrimSpace(stats.FailureCode) != "" { + run.LastFailureCode = strings.TrimSpace(stats.FailureCode) + } +} + +func shouldRetrySubagentError(err error) bool { + code := classifySubagentFailureCode(err) + switch code { + case "", "timeout", "stream_failed", "stream_stale", "context_compacted": + return true + case "continuation_exhausted", "retry_limit": + return false + default: + return !errors.Is(err, context.Canceled) + } +} + +func classifySubagentFailureCode(err error) string { + if err == nil { + return "" + } + if errors.Is(err, context.DeadlineExceeded) { + return "timeout" + } + var execErr *providers.ProviderExecutionError + if errors.As(err, &execErr) && execErr != nil { + if strings.TrimSpace(execErr.Code) != "" { + return strings.TrimSpace(execErr.Code) + } + } + lower := strings.ToLower(strings.TrimSpace(err.Error())) + switch { + case strings.Contains(lower, "max tool iterations"), strings.Contains(lower, "iteration budget exhausted"): + return "retry_limit" + case strings.Contains(lower, "stream stale"): + return "stream_stale" + case strings.Contains(lower, "stream failed"): + return "stream_failed" + case strings.Contains(lower, "continuation exhausted"), strings.Contains(lower, "thinking budget 
exhausted"): + return "continuation_exhausted" + case strings.Contains(lower, "compaction"): + return "context_compacted" + default: + return "" + } +} + func normalizeSubagentIdentifier(in string) string { in = strings.TrimSpace(strings.ToLower(in)) if in == "" { @@ -867,13 +963,14 @@ func (sm *SubagentManager) persistRunLocked(run *SubagentRun, eventType, message cp := cloneSubagentRun(run) _ = sm.runStore.AppendRun(cp) _ = sm.runStore.AppendEvent(SubagentRunEvent{ - RunID: cp.ID, - AgentID: cp.AgentID, - Type: strings.TrimSpace(eventType), - Status: cp.Status, - Message: strings.TrimSpace(message), - RetryCount: cp.RetryCount, - At: cp.Updated, + RunID: cp.ID, + AgentID: cp.AgentID, + Type: strings.TrimSpace(eventType), + Status: cp.Status, + FailureCode: strings.TrimSpace(cp.LastFailureCode), + Message: strings.TrimSpace(message), + RetryCount: cp.RetryCount, + At: cp.Updated, }) } diff --git a/pkg/tools/subagent_budget_test.go b/pkg/tools/subagent_budget_test.go new file mode 100644 index 0000000..6ce147d --- /dev/null +++ b/pkg/tools/subagent_budget_test.go @@ -0,0 +1,123 @@ +package tools + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/YspCoder/clawgo/pkg/providers" +) + +func TestSubagentRunPreservesRemainingIterationBudgetAcrossRetry(t *testing.T) { + t.Parallel() + + manager := NewSubagentManager(nil, t.TempDir(), nil) + attempts := 0 + manager.SetRunFunc(func(ctx context.Context, run *SubagentRun) (string, error) { + attempts++ + budget, ok := SubagentIterationBudget(ctx) + if !ok { + t.Fatal("expected subagent iteration budget in context") + } + if attempts == 1 { + if budget != 3 { + t.Fatalf("expected first attempt budget 3, got %d", budget) + } + RecordSubagentExecutionStats(ctx, SubagentExecutionStats{ + Iterations: 2, + Attempts: 1, + FailureCode: "stream_failed", + }) + return "", providers.NewProviderExecutionError("stream_failed", "stream failed", "stream", true, "test") + } + if budget != 1 { + t.Fatalf("expected retry 
to inherit remaining budget 1, got %d", budget) + } + RecordSubagentExecutionStats(ctx, SubagentExecutionStats{ + Iterations: 1, + Attempts: 1, + }) + return "done", nil + }) + + run, err := manager.SpawnRun(context.Background(), SubagentSpawnOptions{ + Task: "finish task", + AgentID: "tester", + MaxRetries: 1, + RetryBackoff: 1, + MaxToolIterations: 3, + }) + if err != nil { + t.Fatalf("spawn run: %v", err) + } + + waitCtx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + finalRun, _, err := manager.waitRun(waitCtx, run.ID) + if err != nil { + t.Fatalf("wait run: %v", err) + } + if finalRun.Status != RuntimeStatusCompleted { + t.Fatalf("expected completed run, got %+v", finalRun) + } + if finalRun.IterationCount != 3 { + t.Fatalf("expected 3 consumed iterations, got %d", finalRun.IterationCount) + } + if finalRun.AttemptCount != 2 { + t.Fatalf("expected 2 attempts, got %d", finalRun.AttemptCount) + } + events, err := manager.Events(run.ID, 10) + if err != nil { + t.Fatalf("events: %v", err) + } + if len(events) == 0 { + t.Fatal("expected persisted events") + } + foundFailure := false + for _, evt := range events { + if evt.Type == "attempt_failed" && evt.FailureCode == "stream_failed" { + foundFailure = true + } + } + if !foundFailure { + t.Fatalf("expected stream_failed event, got %#v", events) + } +} + +func TestSubagentRunStopsWhenIterationBudgetExhausted(t *testing.T) { + t.Parallel() + + manager := NewSubagentManager(nil, t.TempDir(), nil) + manager.SetRunFunc(func(ctx context.Context, run *SubagentRun) (string, error) { + RecordSubagentExecutionStats(ctx, SubagentExecutionStats{Iterations: 2, Attempts: 1}) + return "", errors.New("max tool iterations exceeded") + }) + + run, err := manager.SpawnRun(context.Background(), SubagentSpawnOptions{ + Task: "exhaust budget", + AgentID: "tester", + MaxRetries: 2, + RetryBackoff: 1, + MaxToolIterations: 2, + }) + if err != nil { + t.Fatalf("spawn run: %v", err) + } + + waitCtx, 
cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + finalRun, _, err := manager.waitRun(waitCtx, run.ID) + if err != nil { + t.Fatalf("wait run: %v", err) + } + if finalRun.Status != RuntimeStatusFailed { + t.Fatalf("expected failed run, got %+v", finalRun) + } + if finalRun.LastFailureCode != "retry_limit" { + t.Fatalf("expected retry_limit failure code, got %q", finalRun.LastFailureCode) + } + if finalRun.IterationCount != 2 { + t.Fatalf("expected consumed iterations to stay at 2, got %d", finalRun.IterationCount) + } +} diff --git a/pkg/tools/subagent_mailbox.go b/pkg/tools/subagent_mailbox.go index 9774eeb..21085f9 100644 --- a/pkg/tools/subagent_mailbox.go +++ b/pkg/tools/subagent_mailbox.go @@ -1,7 +1,6 @@ package tools import ( - "bufio" "encoding/json" "fmt" "os" @@ -10,6 +9,9 @@ import ( "strconv" "strings" "sync" + "time" + + "github.com/YspCoder/clawgo/pkg/jsonlog" ) type AgentThread struct { @@ -36,15 +38,39 @@ type AgentMessage struct { CreatedAt int64 `json:"created_at"` } +type mailboxThreadsMetaFile struct { + Version int `json:"version"` + LastOffset int64 `json:"last_offset,omitempty"` + ThreadSeq int `json:"thread_seq,omitempty"` + UpdatedAt int64 `json:"updated_at,omitempty"` + Threads []AgentThread `json:"threads,omitempty"` +} + +type mailboxMessagesMetaFile struct { + Version int `json:"version"` + LastOffset int64 `json:"last_offset,omitempty"` + MsgSeq int `json:"msg_seq,omitempty"` + UpdatedAt int64 `json:"updated_at,omitempty"` + Messages []AgentMessage `json:"messages,omitempty"` + ThreadMessages map[string][]string `json:"thread_messages,omitempty"` + QueuedByAgent map[string][]string `json:"queued_by_agent,omitempty"` + QueuedByThreadAgent map[string][]string `json:"queued_by_thread_agent,omitempty"` +} + type AgentMailboxStore struct { - dir string - threadsPath string - msgsPath string - mu sync.RWMutex - threads map[string]*AgentThread - messages map[string]*AgentMessage - msgSeq int - threadSeq int 
+ dir string + threadsPath string + msgsPath string + threadsMetaPath string + msgsMetaPath string + mu sync.RWMutex + threads map[string]*AgentThread + messages map[string]*AgentMessage + threadMessages map[string][]string + queuedByAgent map[string][]string + queuedByThreadAgent map[string][]string + msgSeq int + threadSeq int } func NewAgentMailboxStore(workspace string) *AgentMailboxStore { @@ -54,13 +80,18 @@ func NewAgentMailboxStore(workspace string) *AgentMailboxStore { } dir := filepath.Join(workspace, "agents", "runtime") s := &AgentMailboxStore{ - dir: dir, - threadsPath: filepath.Join(dir, "threads.jsonl"), - msgsPath: filepath.Join(dir, "agent_messages.jsonl"), - threads: map[string]*AgentThread{}, - messages: map[string]*AgentMessage{}, + dir: dir, + threadsPath: filepath.Join(dir, "threads.jsonl"), + msgsPath: filepath.Join(dir, "agent_messages.jsonl"), + threadsMetaPath: filepath.Join(dir, "threads.meta.json"), + msgsMetaPath: filepath.Join(dir, "agent_messages.meta.json"), + threads: map[string]*AgentThread{}, + messages: map[string]*AgentMessage{}, + threadMessages: map[string][]string{}, + queuedByAgent: map[string][]string{}, + queuedByThreadAgent: map[string][]string{}, } - _ = os.MkdirAll(dir, 0755) + _ = os.MkdirAll(dir, 0o755) _ = s.load() return s } @@ -68,75 +99,91 @@ func NewAgentMailboxStore(workspace string) *AgentMailboxStore { func (s *AgentMailboxStore) load() error { s.mu.Lock() defer s.mu.Unlock() - s.threads = map[string]*AgentThread{} - s.messages = map[string]*AgentMessage{} + s.resetLocked() if err := s.loadThreadsLocked(); err != nil { return err } - return s.scanMessagesLocked() + return s.loadMessagesLocked() +} + +func (s *AgentMailboxStore) resetLocked() { + s.threads = map[string]*AgentThread{} + s.messages = map[string]*AgentMessage{} + s.threadMessages = map[string][]string{} + s.queuedByAgent = map[string][]string{} + s.queuedByThreadAgent = map[string][]string{} + s.msgSeq = 0 + s.threadSeq = 0 } func (s 
*AgentMailboxStore) loadThreadsLocked() error { - f, err := os.Open(s.threadsPath) + size, err := jsonlog.FileSize(s.threadsPath) if err != nil { - if os.IsNotExist(err) { - return nil - } return err } - defer f.Close() - scanner := bufio.NewScanner(f) - buf := make([]byte, 0, 64*1024) - scanner.Buffer(buf, 2*1024*1024) - for scanner.Scan() { - line := strings.TrimSpace(scanner.Text()) - if line == "" { - continue + var meta mailboxThreadsMetaFile + if err := jsonlog.ReadJSON(s.threadsMetaPath, &meta); err == nil && meta.Version > 0 && meta.LastOffset == size { + for _, thread := range meta.Threads { + cp := thread + cp.Participants = append([]string(nil), thread.Participants...) + s.threads[cp.ThreadID] = &cp } + s.threadSeq = meta.ThreadSeq + return nil + } + if size == 0 { + return s.persistThreadsMetaLocked(size) + } + if err := jsonlog.Scan(s.threadsPath, func(line []byte) error { var thread AgentThread - if err := json.Unmarshal([]byte(line), &thread); err != nil { - continue + if err := json.Unmarshal(line, &thread); err != nil || strings.TrimSpace(thread.ThreadID) == "" { + return nil } cp := thread + cp.Participants = append([]string(nil), thread.Participants...) 
s.threads[thread.ThreadID] = &cp if n := parseThreadSequence(thread.ThreadID); n > s.threadSeq { s.threadSeq = n } - } - return scanner.Err() -} - -func (s *AgentMailboxStore) scanMessagesLocked() error { - f, err := os.Open(s.msgsPath) - if err != nil { - if os.IsNotExist(err) { - return nil - } + return nil + }); err != nil { return err } - defer f.Close() - scanner := bufio.NewScanner(f) - buf := make([]byte, 0, 64*1024) - scanner.Buffer(buf, 2*1024*1024) - for scanner.Scan() { - line := strings.TrimSpace(scanner.Text()) - if line == "" { - continue - } - var msg AgentMessage - if err := json.Unmarshal([]byte(line), &msg); err != nil { - continue - } - if n := parseMessageSequence(msg.MessageID); n > s.msgSeq { - s.msgSeq = n - } - cp := msg - s.messages[msg.MessageID] = &cp - if thread := s.threads[msg.ThreadID]; thread != nil && msg.CreatedAt > thread.UpdatedAt { - thread.UpdatedAt = msg.CreatedAt - } + return s.persistThreadsMetaLocked(size) +} + +func (s *AgentMailboxStore) loadMessagesLocked() error { + size, err := jsonlog.FileSize(s.msgsPath) + if err != nil { + return err } - return scanner.Err() + var meta mailboxMessagesMetaFile + if err := jsonlog.ReadJSON(s.msgsMetaPath, &meta); err == nil && meta.Version > 0 && meta.LastOffset == size { + for _, msg := range meta.Messages { + cp := msg + s.messages[msg.MessageID] = &cp + } + s.threadMessages = cloneStringSliceMap(meta.ThreadMessages) + s.queuedByAgent = cloneStringSliceMap(meta.QueuedByAgent) + s.queuedByThreadAgent = cloneStringSliceMap(meta.QueuedByThreadAgent) + s.msgSeq = meta.MsgSeq + s.reconcileThreadsFromMessagesLocked() + return nil + } + if size == 0 { + return s.persistMessagesMetaLocked(size) + } + if err := jsonlog.Scan(s.msgsPath, func(line []byte) error { + var msg AgentMessage + if err := json.Unmarshal(line, &msg); err != nil || strings.TrimSpace(msg.MessageID) == "" { + return nil + } + s.indexMessageLocked(msg) + return nil + }); err != nil { + return err + } + return 
s.persistMessagesMetaLocked(size) } func (s *AgentMailboxStore) EnsureThread(thread AgentThread) (AgentThread, error) { @@ -145,7 +192,7 @@ func (s *AgentMailboxStore) EnsureThread(thread AgentThread) (AgentThread, error } s.mu.Lock() defer s.mu.Unlock() - if err := os.MkdirAll(s.dir, 0755); err != nil { + if err := os.MkdirAll(s.dir, 0o755); err != nil { return AgentThread{}, err } if strings.TrimSpace(thread.ThreadID) == "" { @@ -165,20 +212,19 @@ func (s *AgentMailboxStore) EnsureThread(thread AgentThread) (AgentThread, error if thread.UpdatedAt <= 0 { thread.UpdatedAt = thread.CreatedAt } - data, err := json.Marshal(thread) + offset, err := jsonlog.AppendLine(s.threadsPath, thread) if err != nil { return AgentThread{}, err } - f, err := os.OpenFile(s.threadsPath, os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0644) - if err != nil { - return AgentThread{}, err - } - defer f.Close() - if _, err := f.Write(append(data, '\n')); err != nil { - return AgentThread{}, err - } cp := thread + cp.Participants = append([]string(nil), thread.Participants...) 
s.threads[thread.ThreadID] = &cp + if n := parseThreadSequence(thread.ThreadID); n > s.threadSeq { + s.threadSeq = n + } + if err := s.persistThreadsMetaLocked(offset); err != nil { + return AgentThread{}, err + } return thread, nil } @@ -188,7 +234,7 @@ func (s *AgentMailboxStore) AppendMessage(msg AgentMessage) (AgentMessage, error } s.mu.Lock() defer s.mu.Unlock() - if err := os.MkdirAll(s.dir, 0755); err != nil { + if err := os.MkdirAll(s.dir, 0o755); err != nil { return AgentMessage{}, err } if strings.TrimSpace(msg.MessageID) == "" { @@ -198,26 +244,17 @@ func (s *AgentMailboxStore) AppendMessage(msg AgentMessage) (AgentMessage, error if strings.TrimSpace(msg.Status) == "" { msg.Status = "queued" } - data, err := json.Marshal(msg) + offset, err := jsonlog.AppendLine(s.msgsPath, msg) if err != nil { return AgentMessage{}, err } - f, err := os.OpenFile(s.msgsPath, os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0644) - if err != nil { + s.indexMessageLocked(msg) + if err := s.persistMessagesMetaLocked(offset); err != nil { return AgentMessage{}, err } - defer f.Close() - if _, err := f.Write(append(data, '\n')); err != nil { + if err := s.persistThreadsMetaLocked(0); err != nil { return AgentMessage{}, err } - if thread := s.threads[msg.ThreadID]; thread != nil { - thread.UpdatedAt = msg.CreatedAt - participants := append([]string(nil), thread.Participants...) 
- participants = append(participants, msg.FromAgent, msg.ToAgent) - thread.Participants = normalizeStringList(participants) - } - cp := msg - s.messages[msg.MessageID] = &cp return msg, nil } @@ -240,30 +277,21 @@ func (s *AgentMailboxStore) MessagesByThread(threadID string, limit int) ([]Agen if s == nil { return nil, nil } - return s.currentMessages(func(msg AgentMessage) bool { - return msg.ThreadID == strings.TrimSpace(threadID) - }, limit), nil + return s.currentIndexedMessages(s.threadMessages[strings.TrimSpace(threadID)], limit), nil } func (s *AgentMailboxStore) Inbox(agentID string, limit int) ([]AgentMessage, error) { if s == nil { return nil, nil } - agentID = strings.TrimSpace(agentID) - return s.currentMessages(func(msg AgentMessage) bool { - return msg.ToAgent == agentID && strings.EqualFold(strings.TrimSpace(msg.Status), "queued") - }, limit), nil + return s.currentIndexedMessages(s.queuedByAgent[strings.TrimSpace(agentID)], limit), nil } func (s *AgentMailboxStore) ThreadInbox(threadID, agentID string, limit int) ([]AgentMessage, error) { if s == nil { return nil, nil } - threadID = strings.TrimSpace(threadID) - agentID = strings.TrimSpace(agentID) - return s.currentMessages(func(msg AgentMessage) bool { - return msg.ThreadID == threadID && msg.ToAgent == agentID && strings.EqualFold(strings.TrimSpace(msg.Status), "queued") - }, limit), nil + return s.currentIndexedMessages(s.queuedByThreadAgent[mailboxThreadAgentKey(threadID, agentID)], limit), nil } func (s *AgentMailboxStore) Message(messageID string) (*AgentMessage, bool) { @@ -305,16 +333,16 @@ func (s *AgentMailboxStore) UpdateMessageStatus(messageID, status string, at int return &msg, nil } -func (s *AgentMailboxStore) currentMessages(match func(AgentMessage) bool, limit int) []AgentMessage { +func (s *AgentMailboxStore) currentIndexedMessages(ids []string, limit int) []AgentMessage { s.mu.RLock() defer s.mu.RUnlock() - var out []AgentMessage - for _, item := range s.messages { - msg := *item 
- if match != nil && !match(msg) { + out := make([]AgentMessage, 0, len(ids)) + for _, id := range ids { + msg := s.messages[id] + if msg == nil { continue } - out = append(out, msg) + out = append(out, *msg) } sort.Slice(out, func(i, j int) bool { if out[i].CreatedAt != out[j].CreatedAt { @@ -328,6 +356,136 @@ func (s *AgentMailboxStore) currentMessages(match func(AgentMessage) bool, limit return out } +func (s *AgentMailboxStore) indexMessageLocked(msg AgentMessage) { + msg.MessageID = strings.TrimSpace(msg.MessageID) + if msg.MessageID == "" { + return + } + if n := parseMessageSequence(msg.MessageID); n > s.msgSeq { + s.msgSeq = n + } + if existing := s.messages[msg.MessageID]; existing != nil { + s.removeQueuedIndexesLocked(*existing) + } + cp := msg + s.messages[msg.MessageID] = &cp + s.threadMessages[msg.ThreadID] = appendUniqueString(s.threadMessages[msg.ThreadID], msg.MessageID) + if strings.EqualFold(strings.TrimSpace(msg.Status), "queued") { + s.queuedByAgent[msg.ToAgent] = appendUniqueString(s.queuedByAgent[msg.ToAgent], msg.MessageID) + s.queuedByThreadAgent[mailboxThreadAgentKey(msg.ThreadID, msg.ToAgent)] = appendUniqueString(s.queuedByThreadAgent[mailboxThreadAgentKey(msg.ThreadID, msg.ToAgent)], msg.MessageID) + } + thread := s.ensureThreadStateLocked(msg.ThreadID) + if thread != nil { + if msg.CreatedAt > thread.UpdatedAt { + thread.UpdatedAt = msg.CreatedAt + } + participants := append([]string(nil), thread.Participants...) 
+ participants = append(participants, msg.FromAgent, msg.ToAgent) + thread.Participants = normalizeStringList(participants) + } +} + +func (s *AgentMailboxStore) ensureThreadStateLocked(threadID string) *AgentThread { + threadID = strings.TrimSpace(threadID) + if threadID == "" { + return nil + } + thread := s.threads[threadID] + if thread != nil { + return thread + } + thread = &AgentThread{ + ThreadID: threadID, + Status: "open", + CreatedAt: 1, + UpdatedAt: 1, + Participants: nil, + } + s.threads[threadID] = thread + if n := parseThreadSequence(threadID); n > s.threadSeq { + s.threadSeq = n + } + return thread +} + +func (s *AgentMailboxStore) removeQueuedIndexesLocked(msg AgentMessage) { + if !strings.EqualFold(strings.TrimSpace(msg.Status), "queued") { + return + } + s.queuedByAgent[msg.ToAgent] = removeStringValue(s.queuedByAgent[msg.ToAgent], msg.MessageID) + s.queuedByThreadAgent[mailboxThreadAgentKey(msg.ThreadID, msg.ToAgent)] = removeStringValue(s.queuedByThreadAgent[mailboxThreadAgentKey(msg.ThreadID, msg.ToAgent)], msg.MessageID) +} + +func (s *AgentMailboxStore) reconcileThreadsFromMessagesLocked() { + for _, msg := range s.messages { + s.ensureThreadStateLocked(msg.ThreadID) + } + for _, msg := range s.messages { + thread := s.threads[msg.ThreadID] + if thread == nil { + continue + } + if msg.CreatedAt > thread.UpdatedAt { + thread.UpdatedAt = msg.CreatedAt + } + participants := append([]string(nil), thread.Participants...) 
+ participants = append(participants, msg.FromAgent, msg.ToAgent) + thread.Participants = normalizeStringList(participants) + } +} + +func (s *AgentMailboxStore) persistThreadsMetaLocked(offset int64) error { + if offset <= 0 { + size, err := jsonlog.FileSize(s.threadsPath) + if err != nil { + return err + } + offset = size + } + meta := mailboxThreadsMetaFile{ + Version: 1, + LastOffset: offset, + ThreadSeq: s.threadSeq, + UpdatedAt: time.Now().UnixMilli(), + Threads: make([]AgentThread, 0, len(s.threads)), + } + for _, thread := range s.threads { + cp := *thread + cp.Participants = append([]string(nil), thread.Participants...) + meta.Threads = append(meta.Threads, cp) + } + sort.Slice(meta.Threads, func(i, j int) bool { + if meta.Threads[i].UpdatedAt != meta.Threads[j].UpdatedAt { + return meta.Threads[i].UpdatedAt > meta.Threads[j].UpdatedAt + } + return meta.Threads[i].ThreadID < meta.Threads[j].ThreadID + }) + return jsonlog.WriteJSON(s.threadsMetaPath, meta) +} + +func (s *AgentMailboxStore) persistMessagesMetaLocked(offset int64) error { + meta := mailboxMessagesMetaFile{ + Version: 1, + LastOffset: offset, + MsgSeq: s.msgSeq, + UpdatedAt: time.Now().UnixMilli(), + Messages: make([]AgentMessage, 0, len(s.messages)), + ThreadMessages: cloneStringSliceMap(s.threadMessages), + QueuedByAgent: cloneStringSliceMap(s.queuedByAgent), + QueuedByThreadAgent: cloneStringSliceMap(s.queuedByThreadAgent), + } + for _, msg := range s.messages { + meta.Messages = append(meta.Messages, *msg) + } + sort.Slice(meta.Messages, func(i, j int) bool { + if meta.Messages[i].CreatedAt != meta.Messages[j].CreatedAt { + return meta.Messages[i].CreatedAt < meta.Messages[j].CreatedAt + } + return meta.Messages[i].MessageID < meta.Messages[j].MessageID + }) + return jsonlog.WriteJSON(s.msgsMetaPath, meta) +} + func parseThreadSequence(threadID string) int { threadID = strings.TrimSpace(threadID) if !strings.HasPrefix(threadID, "thread-") { @@ -386,3 +544,45 @@ func 
parseMessageSequence(messageID string) int { n, _ := strconv.Atoi(strings.TrimPrefix(messageID, "msg-")) return n } + +func mailboxThreadAgentKey(threadID, agentID string) string { + return strings.TrimSpace(threadID) + "::" + strings.TrimSpace(agentID) +} + +func cloneStringSliceMap(src map[string][]string) map[string][]string { + if len(src) == 0 { + return map[string][]string{} + } + out := make(map[string][]string, len(src)) + for key, values := range src { + out[key] = append([]string(nil), values...) + } + return out +} + +func appendUniqueString(values []string, value string) []string { + value = strings.TrimSpace(value) + if value == "" { + return values + } + for _, existing := range values { + if existing == value { + return values + } + } + return append(values, value) +} + +func removeStringValue(values []string, value string) []string { + if len(values) == 0 { + return values + } + out := values[:0] + for _, existing := range values { + if existing == value { + continue + } + out = append(out, existing) + } + return append([]string(nil), out...) 
+} diff --git a/pkg/tools/subagent_mailbox_test.go b/pkg/tools/subagent_mailbox_test.go new file mode 100644 index 0000000..0645cfb --- /dev/null +++ b/pkg/tools/subagent_mailbox_test.go @@ -0,0 +1,81 @@ +package tools + +import ( + "os" + "path/filepath" + "testing" +) + +func TestAgentMailboxStoreReloadsInboxIndexes(t *testing.T) { + t.Parallel() + + workspace := t.TempDir() + store := NewAgentMailboxStore(workspace) + + thread, err := store.EnsureThread(AgentThread{ + Owner: "planner", + Topic: "handoff", + CreatedAt: 10, + UpdatedAt: 10, + }) + if err != nil { + t.Fatalf("ensure thread: %v", err) + } + msg, err := store.AppendMessage(AgentMessage{ + ThreadID: thread.ThreadID, + FromAgent: "planner", + ToAgent: "worker-a", + Type: "task", + Content: "check deploy logs", + RequiresReply: true, + Status: "queued", + CreatedAt: 20, + }) + if err != nil { + t.Fatalf("append message: %v", err) + } + + inbox, err := store.Inbox("worker-a", 10) + if err != nil { + t.Fatalf("inbox: %v", err) + } + if len(inbox) != 1 || inbox[0].MessageID != msg.MessageID { + t.Fatalf("unexpected inbox: %#v", inbox) + } + + reloaded := NewAgentMailboxStore(workspace) + threadInbox, err := reloaded.ThreadInbox(thread.ThreadID, "worker-a", 10) + if err != nil { + t.Fatalf("thread inbox: %v", err) + } + if len(threadInbox) != 1 || threadInbox[0].Content != "check deploy logs" { + t.Fatalf("unexpected thread inbox: %#v", threadInbox) + } + + if _, err := reloaded.UpdateMessageStatus(msg.MessageID, "processed", 30); err != nil { + t.Fatalf("update status: %v", err) + } + inbox, err = reloaded.Inbox("worker-a", 10) + if err != nil { + t.Fatalf("inbox after status update: %v", err) + } + if len(inbox) != 0 { + t.Fatalf("expected empty inbox after status update, got %#v", inbox) + } + + runtimeDir := filepath.Join(workspace, "agents", "runtime") + if err := os.Remove(filepath.Join(runtimeDir, "threads.meta.json")); err != nil { + t.Fatalf("remove thread meta: %v", err) + } + if err := 
os.Remove(filepath.Join(runtimeDir, "agent_messages.meta.json")); err != nil { + t.Fatalf("remove messages meta: %v", err) + } + rebuilt := NewAgentMailboxStore(workspace) + messages, err := rebuilt.MessagesByThread(thread.ThreadID, 10) + if err != nil { + t.Fatalf("messages by thread: %v", err) + } + if len(messages) != 1 || messages[0].Status != "processed" { + t.Fatalf("unexpected rebuilt messages: %#v", messages) + } +} diff --git a/pkg/tools/subagent_profile.go b/pkg/tools/subagent_profile.go index 765c404..541399e 100644 --- a/pkg/tools/subagent_profile.go +++ b/pkg/tools/subagent_profile.go @@ -16,24 +16,25 @@ import ( ) type SubagentProfile struct { - AgentID string `json:"agent_id"` - Name string `json:"name"` - Transport string `json:"transport,omitempty"` - ParentAgentID string `json:"parent_agent_id,omitempty"` - NotifyMainPolicy string `json:"notify_main_policy,omitempty"` - Role string `json:"role,omitempty"` - SystemPromptFile string `json:"system_prompt_file,omitempty"` - ToolAllowlist []string `json:"tool_allowlist,omitempty"` - MemoryNamespace string `json:"memory_namespace,omitempty"` - MaxRetries int `json:"max_retries,omitempty"` - RetryBackoff int `json:"retry_backoff_ms,omitempty"` - TimeoutSec int `json:"timeout_sec,omitempty"` - MaxTaskChars int `json:"max_task_chars,omitempty"` - MaxResultChars int `json:"max_result_chars,omitempty"` - Status string `json:"status"` - CreatedAt int64 `json:"created_at"` - UpdatedAt int64 `json:"updated_at"` - ManagedBy string `json:"managed_by,omitempty"` + AgentID string `json:"agent_id"` + Name string `json:"name"` + Transport string `json:"transport,omitempty"` + ParentAgentID string `json:"parent_agent_id,omitempty"` + NotifyMainPolicy string `json:"notify_main_policy,omitempty"` + Role string `json:"role,omitempty"` + SystemPromptFile string `json:"system_prompt_file,omitempty"` + ToolAllowlist []string `json:"tool_allowlist,omitempty"` + MemoryNamespace string `json:"memory_namespace,omitempty"` + 
MaxRetries int `json:"max_retries,omitempty"` + RetryBackoff int `json:"retry_backoff_ms,omitempty"` + TimeoutSec int `json:"timeout_sec,omitempty"` + MaxToolIterations int `json:"max_tool_iterations,omitempty"` + MaxTaskChars int `json:"max_task_chars,omitempty"` + MaxResultChars int `json:"max_result_chars,omitempty"` + Status string `json:"status"` + CreatedAt int64 `json:"created_at"` + UpdatedAt int64 `json:"updated_at"` + ManagedBy string `json:"managed_by,omitempty"` } type SubagentProfileStore struct { @@ -192,6 +193,7 @@ func normalizeSubagentProfile(in SubagentProfile) SubagentProfile { p.MaxRetries = clampInt(p.MaxRetries, 0, 8) p.RetryBackoff = clampInt(p.RetryBackoff, 500, 120000) p.TimeoutSec = clampInt(p.TimeoutSec, 0, 3600) + p.MaxToolIterations = clampInt(p.MaxToolIterations, 0, 200) p.MaxTaskChars = clampInt(p.MaxTaskChars, 0, 400000) p.MaxResultChars = clampInt(p.MaxResultChars, 0, 400000) return p @@ -336,22 +338,23 @@ func profileFromConfig(agentID string, subcfg config.SubagentConfig) SubagentPro status = "disabled" } return normalizeSubagentProfile(SubagentProfile{ - AgentID: agentID, - Name: strings.TrimSpace(subcfg.DisplayName), - Transport: strings.TrimSpace(subcfg.Transport), - ParentAgentID: strings.TrimSpace(subcfg.ParentAgentID), - NotifyMainPolicy: strings.TrimSpace(subcfg.NotifyMainPolicy), - Role: strings.TrimSpace(subcfg.Role), - SystemPromptFile: strings.TrimSpace(subcfg.SystemPromptFile), - ToolAllowlist: append([]string(nil), subcfg.Tools.Allowlist...), - MemoryNamespace: strings.TrimSpace(subcfg.MemoryNamespace), - MaxRetries: subcfg.Runtime.MaxRetries, - RetryBackoff: subcfg.Runtime.RetryBackoffMs, - TimeoutSec: subcfg.Runtime.TimeoutSec, - MaxTaskChars: subcfg.Runtime.MaxTaskChars, - MaxResultChars: subcfg.Runtime.MaxResultChars, - Status: status, - ManagedBy: "config.json", + AgentID: agentID, + Name: strings.TrimSpace(subcfg.DisplayName), + Transport: strings.TrimSpace(subcfg.Transport), + ParentAgentID: 
strings.TrimSpace(subcfg.ParentAgentID), + NotifyMainPolicy: strings.TrimSpace(subcfg.NotifyMainPolicy), + Role: strings.TrimSpace(subcfg.Role), + SystemPromptFile: strings.TrimSpace(subcfg.SystemPromptFile), + ToolAllowlist: append([]string(nil), subcfg.Tools.Allowlist...), + MemoryNamespace: strings.TrimSpace(subcfg.MemoryNamespace), + MaxRetries: subcfg.Runtime.MaxRetries, + RetryBackoff: subcfg.Runtime.RetryBackoffMs, + TimeoutSec: subcfg.Runtime.TimeoutSec, + MaxToolIterations: subcfg.Runtime.MaxToolIterations, + MaxTaskChars: subcfg.Runtime.MaxTaskChars, + MaxResultChars: subcfg.Runtime.MaxResultChars, + Status: status, + ManagedBy: "config.json", }) } @@ -389,11 +392,12 @@ func (t *SubagentProfileTool) Parameters() map[string]interface{} { "description": "Tool allowlist entries. Supports tool names, '*'/'all', and grouped tokens like 'group:files_read'.", "items": map[string]interface{}{"type": "string"}, }, - "max_retries": map[string]interface{}{"type": "integer", "description": "Retry limit for subagent task execution."}, - "retry_backoff_ms": map[string]interface{}{"type": "integer", "description": "Backoff between retries in milliseconds."}, - "timeout_sec": map[string]interface{}{"type": "integer", "description": "Per-attempt timeout in seconds."}, - "max_task_chars": map[string]interface{}{"type": "integer", "description": "Task input size quota (characters)."}, - "max_result_chars": map[string]interface{}{"type": "integer", "description": "Result output size quota (characters)."}, + "max_retries": map[string]interface{}{"type": "integer", "description": "Retry limit for subagent task execution."}, + "retry_backoff_ms": map[string]interface{}{"type": "integer", "description": "Backoff between retries in milliseconds."}, + "timeout_sec": map[string]interface{}{"type": "integer", "description": "Per-attempt timeout in seconds."}, + "max_tool_iterations": map[string]interface{}{"type": "integer", "description": "Independent tool-calling iteration budget 
for this subagent."}, + "max_task_chars": map[string]interface{}{"type": "integer", "description": "Task input size quota (characters)."}, + "max_result_chars": map[string]interface{}{"type": "integer", "description": "Result output size quota (characters)."}, }, "required": []string{"action"}, } @@ -445,19 +449,20 @@ func (t *SubagentProfileTool) Execute(ctx context.Context, args map[string]inter return "subagent profile already exists", nil } p := SubagentProfile{ - AgentID: agentID, - Name: stringArg(args, "name"), - NotifyMainPolicy: stringArg(args, "notify_main_policy"), - Role: stringArg(args, "role"), - SystemPromptFile: stringArg(args, "system_prompt_file"), - MemoryNamespace: stringArg(args, "memory_namespace"), - Status: stringArg(args, "status"), - ToolAllowlist: parseStringList(args["tool_allowlist"]), - MaxRetries: profileIntArg(args, "max_retries"), - RetryBackoff: profileIntArg(args, "retry_backoff_ms"), - TimeoutSec: profileIntArg(args, "timeout_sec"), - MaxTaskChars: profileIntArg(args, "max_task_chars"), - MaxResultChars: profileIntArg(args, "max_result_chars"), + AgentID: agentID, + Name: stringArg(args, "name"), + NotifyMainPolicy: stringArg(args, "notify_main_policy"), + Role: stringArg(args, "role"), + SystemPromptFile: stringArg(args, "system_prompt_file"), + MemoryNamespace: stringArg(args, "memory_namespace"), + Status: stringArg(args, "status"), + ToolAllowlist: parseStringList(args["tool_allowlist"]), + MaxRetries: profileIntArg(args, "max_retries"), + RetryBackoff: profileIntArg(args, "retry_backoff_ms"), + TimeoutSec: profileIntArg(args, "timeout_sec"), + MaxToolIterations: profileIntArg(args, "max_tool_iterations"), + MaxTaskChars: profileIntArg(args, "max_task_chars"), + MaxResultChars: profileIntArg(args, "max_result_chars"), } saved, err := t.store.Upsert(p) if err != nil { @@ -506,6 +511,9 @@ func (t *SubagentProfileTool) Execute(ctx context.Context, args map[string]inter if _, ok := args["timeout_sec"]; ok { next.TimeoutSec = 
profileIntArg(args, "timeout_sec") } + if _, ok := args["max_tool_iterations"]; ok { + next.MaxToolIterations = profileIntArg(args, "max_tool_iterations") + } if _, ok := args["max_task_chars"]; ok { next.MaxTaskChars = profileIntArg(args, "max_task_chars") } diff --git a/pkg/tools/subagent_runtime_context.go b/pkg/tools/subagent_runtime_context.go new file mode 100644 index 0000000..08b795b --- /dev/null +++ b/pkg/tools/subagent_runtime_context.go @@ -0,0 +1,51 @@ +package tools + +import "context" + +type SubagentExecutionStats struct { + Iterations int + Attempts int + Restarts int + FailureCode string +} + +type subagentExecutionStatsKey struct{} +type subagentIterationBudgetKey struct{} + +func WithSubagentExecutionStats(ctx context.Context, stats *SubagentExecutionStats) context.Context { + if ctx == nil { + ctx = context.Background() + } + return context.WithValue(ctx, subagentExecutionStatsKey{}, stats) +} + +func RecordSubagentExecutionStats(ctx context.Context, delta SubagentExecutionStats) { + if ctx == nil { + return + } + stats, _ := ctx.Value(subagentExecutionStatsKey{}).(*SubagentExecutionStats) + if stats == nil { + return + } + stats.Iterations += delta.Iterations + stats.Attempts += delta.Attempts + stats.Restarts += delta.Restarts + if delta.FailureCode != "" { + stats.FailureCode = delta.FailureCode + } +} + +func WithSubagentIterationBudget(ctx context.Context, budget int) context.Context { + if ctx == nil { + ctx = context.Background() + } + return context.WithValue(ctx, subagentIterationBudgetKey{}, budget) +} + +func SubagentIterationBudget(ctx context.Context) (int, bool) { + if ctx == nil { + return 0, false + } + budget, ok := ctx.Value(subagentIterationBudgetKey{}).(int) + return budget, ok && budget > 0 +} diff --git a/pkg/tools/subagent_store.go b/pkg/tools/subagent_store.go index 6005e12..1f49452 100644 --- a/pkg/tools/subagent_store.go +++ b/pkg/tools/subagent_store.go @@ -1,7 +1,6 @@ package tools import ( - "bufio" "encoding/json" 
"os" "path/filepath" @@ -9,24 +8,47 @@ import ( "strconv" "strings" "sync" + "time" + + "github.com/YspCoder/clawgo/pkg/jsonlog" ) type SubagentRunEvent struct { - RunID string `json:"run_id"` - AgentID string `json:"agent_id,omitempty"` - Type string `json:"type"` - Status string `json:"status,omitempty"` - Message string `json:"message,omitempty"` - RetryCount int `json:"retry_count,omitempty"` - At int64 `json:"ts"` + RunID string `json:"run_id"` + AgentID string `json:"agent_id,omitempty"` + Type string `json:"type"` + Status string `json:"status,omitempty"` + FailureCode string `json:"failure_code,omitempty"` + Message string `json:"message,omitempty"` + RetryCount int `json:"retry_count,omitempty"` + At int64 `json:"ts"` +} + +type subagentRunsMetaFile struct { + Version int `json:"version"` + LastOffset int64 `json:"last_offset,omitempty"` + NextIDSeed int `json:"next_id_seed,omitempty"` + UpdatedAt int64 `json:"updated_at,omitempty"` + Runs []SubagentRun `json:"runs,omitempty"` +} + +type subagentEventsMetaFile struct { + Version int `json:"version"` + LastOffset int64 `json:"last_offset,omitempty"` + UpdatedAt int64 `json:"updated_at,omitempty"` + EventsByRun map[string][]SubagentRunEvent `json:"events_by_run,omitempty"` } type SubagentRunStore struct { - dir string - runsPath string - eventsPath string - mu sync.RWMutex - runs map[string]*SubagentRun + dir string + runsPath string + eventsPath string + runsMetaPath string + eventsMetaPath string + mu sync.RWMutex + runs map[string]*SubagentRun + events map[string][]SubagentRunEvent + nextIDSeed int } func NewSubagentRunStore(workspace string) *SubagentRunStore { @@ -36,12 +58,16 @@ func NewSubagentRunStore(workspace string) *SubagentRunStore { } dir := filepath.Join(workspace, "agents", "runtime") store := &SubagentRunStore{ - dir: dir, - runsPath: filepath.Join(dir, "subagent_runs.jsonl"), - eventsPath: filepath.Join(dir, "subagent_events.jsonl"), - runs: map[string]*SubagentRun{}, + dir: dir, + 
runsPath: filepath.Join(dir, "subagent_runs.jsonl"), + eventsPath: filepath.Join(dir, "subagent_events.jsonl"), + runsMetaPath: filepath.Join(dir, "subagent_runs.meta.json"), + eventsMetaPath: filepath.Join(dir, "subagent_events.meta.json"), + runs: map[string]*SubagentRun{}, + events: map[string][]SubagentRunEvent{}, + nextIDSeed: 1, } - _ = os.MkdirAll(dir, 0755) + _ = os.MkdirAll(dir, 0o755) _ = store.load() return store } @@ -51,48 +77,80 @@ func (s *SubagentRunStore) load() error { defer s.mu.Unlock() s.runs = map[string]*SubagentRun{} - f, err := os.Open(s.runsPath) - if err != nil { - if os.IsNotExist(err) { - return nil - } + s.events = map[string][]SubagentRunEvent{} + s.nextIDSeed = 1 + + if err := s.loadRunsLocked(); err != nil { return err } - defer f.Close() + return s.loadEventsLocked() +} - scanner := bufio.NewScanner(f) - buf := make([]byte, 0, 64*1024) - scanner.Buffer(buf, 2*1024*1024) - for scanner.Scan() { - line := strings.TrimSpace(scanner.Text()) - if line == "" { - continue - } - var record RunRecord - if err := json.Unmarshal([]byte(line), &record); err == nil && strings.TrimSpace(record.ID) != "" { - run := &SubagentRun{ - ID: record.ID, - Task: record.Input, - AgentID: record.AgentID, - ThreadID: record.ThreadID, - CorrelationID: record.CorrelationID, - ParentRunID: record.ParentRunID, - Status: record.Status, - Result: record.Output, - Created: record.CreatedAt, - Updated: record.UpdatedAt, - } - s.runs[run.ID] = run - continue - } - var run SubagentRun - if err := json.Unmarshal([]byte(line), &run); err != nil { - continue - } - cp := cloneSubagentRun(&run) - s.runs[run.ID] = cp +func (s *SubagentRunStore) loadRunsLocked() error { + size, err := jsonlog.FileSize(s.runsPath) + if err != nil { + return err } - return scanner.Err() + var meta subagentRunsMetaFile + if err := jsonlog.ReadJSON(s.runsMetaPath, &meta); err == nil && meta.Version > 0 && meta.LastOffset == size { + for _, run := range meta.Runs { + cp := cloneSubagentRun(&run) + 
s.runs[cp.ID] = cp + } + s.nextIDSeed = meta.NextIDSeed + if s.nextIDSeed <= 0 { + s.nextIDSeed = deriveNextRunSeed(s.runs) + } + return nil + } + + if size == 0 { + return s.persistRunsMetaLocked(size) + } + if err := jsonlog.Scan(s.runsPath, func(line []byte) error { + run, ok := decodeSubagentRunLine(line) + if !ok { + return nil + } + s.runs[run.ID] = run + return nil + }); err != nil { + return err + } + s.nextIDSeed = deriveNextRunSeed(s.runs) + return s.persistRunsMetaLocked(size) +} + +func (s *SubagentRunStore) loadEventsLocked() error { + size, err := jsonlog.FileSize(s.eventsPath) + if err != nil { + return err + } + var meta subagentEventsMetaFile + if err := jsonlog.ReadJSON(s.eventsMetaPath, &meta); err == nil && meta.Version > 0 && meta.LastOffset == size { + s.events = cloneEventsByRun(meta.EventsByRun) + return nil + } + if size == 0 { + return s.persistEventsMetaLocked(size) + } + eventsByRun := map[string][]SubagentRunEvent{} + if err := jsonlog.Scan(s.eventsPath, func(line []byte) error { + evt, ok := decodeSubagentEventLine(line) + if !ok { + return nil + } + eventsByRun[evt.RunID] = append(eventsByRun[evt.RunID], evt) + return nil + }); err != nil { + return err + } + for runID, events := range eventsByRun { + sort.Slice(events, func(i, j int) bool { return events[i].At < events[j].At }) + eventsByRun[runID] = events + } + s.events = eventsByRun + return s.persistEventsMetaLocked(size) } func (s *SubagentRunStore) AppendRun(run *SubagentRun) error { @@ -100,26 +158,22 @@ func (s *SubagentRunStore) AppendRun(run *SubagentRun) error { return nil } cp := cloneSubagentRun(run) - data, err := json.Marshal(runToRunRecord(cp)) - if err != nil { - return err - } + record := runToRunRecord(cp) s.mu.Lock() defer s.mu.Unlock() - if err := os.MkdirAll(s.dir, 0755); err != nil { + if err := os.MkdirAll(s.dir, 0o755); err != nil { return err } - f, err := os.OpenFile(s.runsPath, os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0644) + offset, err := 
jsonlog.AppendLine(s.runsPath, record) if err != nil { return err } - defer f.Close() - if _, err := f.Write(append(data, '\n')); err != nil { - return err - } s.runs[cp.ID] = cp - return nil + if next := parseSubagentSequence(cp.ID) + 1; next > s.nextIDSeed { + s.nextIDSeed = next + } + return s.persistRunsMetaLocked(offset) } func (s *SubagentRunStore) AppendEvent(evt SubagentRunEvent) error { @@ -127,32 +181,30 @@ func (s *SubagentRunStore) AppendEvent(evt SubagentRunEvent) error { return nil } record := EventRecord{ - ID: EventRecordID(evt.RunID, evt.Type, evt.At), - RunID: evt.RunID, - RequestID: evt.RunID, - AgentID: evt.AgentID, - Type: evt.Type, - Status: evt.Status, - Message: evt.Message, - RetryCount: evt.RetryCount, - At: evt.At, - } - data, err := json.Marshal(record) - if err != nil { - return err + ID: EventRecordID(evt.RunID, evt.Type, evt.At), + RunID: evt.RunID, + RequestID: evt.RunID, + AgentID: evt.AgentID, + Type: evt.Type, + Status: evt.Status, + FailureCode: evt.FailureCode, + Message: evt.Message, + RetryCount: evt.RetryCount, + At: evt.At, } + s.mu.Lock() defer s.mu.Unlock() - if err := os.MkdirAll(s.dir, 0755); err != nil { + if err := os.MkdirAll(s.dir, 0o755); err != nil { return err } - f, err := os.OpenFile(s.eventsPath, os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0644) + offset, err := jsonlog.AppendLine(s.eventsPath, record) if err != nil { return err } - defer f.Close() - _, err = f.Write(append(data, '\n')) - return err + s.events[evt.RunID] = append(s.events[evt.RunID], evt) + sort.Slice(s.events[evt.RunID], func(i, j int) bool { return s.events[evt.RunID][i].At < s.events[evt.RunID][j].At }) + return s.persistEventsMetaLocked(offset) } func (s *SubagentRunStore) Get(runID string) (*SubagentRun, bool) { @@ -191,54 +243,14 @@ func (s *SubagentRunStore) Events(runID string, limit int) ([]SubagentRunEvent, if s == nil { return nil, nil } - f, err := os.Open(s.eventsPath) - if err != nil { - if os.IsNotExist(err) { - return nil, nil - } - 
return nil, err + s.mu.RLock() + defer s.mu.RUnlock() + items := append([]SubagentRunEvent(nil), s.events[strings.TrimSpace(runID)]...) + sort.Slice(items, func(i, j int) bool { return items[i].At < items[j].At }) + if limit > 0 && len(items) > limit { + items = items[len(items)-limit:] } - defer f.Close() - - runID = strings.TrimSpace(runID) - events := make([]SubagentRunEvent, 0) - scanner := bufio.NewScanner(f) - buf := make([]byte, 0, 64*1024) - scanner.Buffer(buf, 2*1024*1024) - for scanner.Scan() { - line := strings.TrimSpace(scanner.Text()) - if line == "" { - continue - } - var evt SubagentRunEvent - if err := json.Unmarshal([]byte(line), &evt); err != nil { - var record EventRecord - if err := json.Unmarshal([]byte(line), &record); err != nil { - continue - } - evt = SubagentRunEvent{ - RunID: record.RunID, - AgentID: record.AgentID, - Type: record.Type, - Status: record.Status, - Message: record.Message, - RetryCount: record.RetryCount, - At: record.At, - } - } - if evt.RunID != runID { - continue - } - events = append(events, evt) - } - if err := scanner.Err(); err != nil { - return nil, err - } - sort.Slice(events, func(i, j int) bool { return events[i].At < events[j].At }) - if limit > 0 && len(events) > limit { - events = events[len(events)-limit:] - } - return events, nil + return items, nil } func (s *SubagentRunStore) NextIDSeed() int { @@ -247,8 +259,46 @@ func (s *SubagentRunStore) NextIDSeed() int { } s.mu.RLock() defer s.mu.RUnlock() + if s.nextIDSeed <= 0 { + return 1 + } + return s.nextIDSeed +} + +func (s *SubagentRunStore) persistRunsMetaLocked(offset int64) error { + meta := subagentRunsMetaFile{ + Version: 1, + LastOffset: offset, + NextIDSeed: maxRunSeed(deriveNextRunSeed(s.runs), s.nextIDSeed, 1), + UpdatedAt: time.Now().UnixMilli(), + Runs: make([]SubagentRun, 0, len(s.runs)), + } + for _, run := range s.runs { + meta.Runs = append(meta.Runs, *cloneSubagentRun(run)) + } + sort.Slice(meta.Runs, func(i, j int) bool { + if 
meta.Runs[i].Created != meta.Runs[j].Created { + return meta.Runs[i].Created > meta.Runs[j].Created + } + return meta.Runs[i].ID > meta.Runs[j].ID + }) + s.nextIDSeed = meta.NextIDSeed + return jsonlog.WriteJSON(s.runsMetaPath, meta) +} + +func (s *SubagentRunStore) persistEventsMetaLocked(offset int64) error { + meta := subagentEventsMetaFile{ + Version: 1, + LastOffset: offset, + UpdatedAt: time.Now().UnixMilli(), + EventsByRun: cloneEventsByRun(s.events), + } + return jsonlog.WriteJSON(s.eventsMetaPath, meta) +} + +func deriveNextRunSeed(runs map[string]*SubagentRun) int { maxSeq := 0 - for runID := range s.runs { + for runID := range runs { if n := parseSubagentSequence(runID); n > maxSeq { maxSeq = n } @@ -268,6 +318,71 @@ func parseSubagentSequence(runID string) int { return n } +func decodeSubagentRunLine(line []byte) (*SubagentRun, bool) { + var record RunRecord + if err := json.Unmarshal(line, &record); err == nil && strings.TrimSpace(record.ID) != "" { + return &SubagentRun{ + ID: record.ID, + Task: record.Input, + AgentID: record.AgentID, + ThreadID: record.ThreadID, + CorrelationID: record.CorrelationID, + ParentRunID: record.ParentRunID, + Status: record.Status, + Result: record.Output, + Created: record.CreatedAt, + Updated: record.UpdatedAt, + }, true + } + var run SubagentRun + if err := json.Unmarshal(line, &run); err != nil || strings.TrimSpace(run.ID) == "" { + return nil, false + } + return cloneSubagentRun(&run), true +} + +func decodeSubagentEventLine(line []byte) (SubagentRunEvent, bool) { + var evt SubagentRunEvent + if err := json.Unmarshal(line, &evt); err == nil && strings.TrimSpace(evt.RunID) != "" { + return evt, true + } + var record EventRecord + if err := json.Unmarshal(line, &record); err != nil || strings.TrimSpace(record.RunID) == "" { + return SubagentRunEvent{}, false + } + return SubagentRunEvent{ + RunID: record.RunID, + AgentID: record.AgentID, + Type: record.Type, + Status: record.Status, + FailureCode: record.FailureCode, + 
Message: record.Message, + RetryCount: record.RetryCount, + At: record.At, + }, true +} + +func cloneEventsByRun(src map[string][]SubagentRunEvent) map[string][]SubagentRunEvent { + if len(src) == 0 { + return map[string][]SubagentRunEvent{} + } + out := make(map[string][]SubagentRunEvent, len(src)) + for key, items := range src { + out[key] = append([]SubagentRunEvent(nil), items...) + } + return out +} + +func maxRunSeed(values ...int) int { + best := 0 + for _, value := range values { + if value > best { + best = value + } + } + return best +} + func cloneSubagentRun(run *SubagentRun) *SubagentRun { if run == nil { return nil @@ -336,4 +451,3 @@ func runToRunRecord(run *SubagentRun) RunRecord { UpdatedAt: run.Updated, } } - diff --git a/pkg/tools/subagent_store_test.go b/pkg/tools/subagent_store_test.go new file mode 100644 index 0000000..9988fd3 --- /dev/null +++ b/pkg/tools/subagent_store_test.go @@ -0,0 +1,70 @@ +package tools + +import ( + "os" + "path/filepath" + "testing" +) + +func TestSubagentRunStoreReloadsFromMetaAndJSONL(t *testing.T) { + t.Parallel() + + workspace := t.TempDir() + store := NewSubagentRunStore(workspace) + run := &SubagentRun{ + ID: "subagent-0007", + Task: "review deploy plan", + AgentID: "worker-a", + Status: RuntimeStatusCompleted, + Result: "done", + Created: 10, + Updated: 20, + } + if err := store.AppendRun(run); err != nil { + t.Fatalf("append run: %v", err) + } + if err := store.AppendEvent(SubagentRunEvent{ + RunID: run.ID, + AgentID: run.AgentID, + Type: "completed", + Status: RuntimeStatusCompleted, + Message: "done", + At: 20, + }); err != nil { + t.Fatalf("append event: %v", err) + } + + reloaded := NewSubagentRunStore(workspace) + if got, ok := reloaded.Get(run.ID); !ok || got.Result != "done" { + t.Fatalf("expected run restored, got %#v ok=%v", got, ok) + } + if seed := reloaded.NextIDSeed(); seed != 8 { + t.Fatalf("expected next seed 8, got %d", seed) + } + events, err := reloaded.Events(run.ID, 10) + if err != nil { + 
t.Fatalf("events: %v", err) + } + if len(events) != 1 || events[0].Type != "completed" { + t.Fatalf("unexpected events: %#v", events) + } + + runtimeDir := filepath.Join(workspace, "agents", "runtime") + if err := os.Remove(filepath.Join(runtimeDir, "subagent_runs.meta.json")); err != nil { + t.Fatalf("remove run meta: %v", err) + } + if err := os.Remove(filepath.Join(runtimeDir, "subagent_events.meta.json")); err != nil { + t.Fatalf("remove event meta: %v", err) + } + rebuilt := NewSubagentRunStore(workspace) + if seed := rebuilt.NextIDSeed(); seed != 8 { + t.Fatalf("expected rebuilt next seed 8, got %d", seed) + } + events, err = rebuilt.Events(run.ID, 10) + if err != nil { + t.Fatalf("events after rebuild: %v", err) + } + if len(events) != 1 || events[0].Message != "done" { + t.Fatalf("unexpected rebuilt events: %#v", events) + } +} diff --git a/pkg/tools/tool_allowlist_groups.go b/pkg/tools/tool_allowlist_groups.go index a301c2b..715d2f5 100644 --- a/pkg/tools/tool_allowlist_groups.go +++ b/pkg/tools/tool_allowlist_groups.go @@ -27,9 +27,9 @@ var defaultToolAllowlistGroups = []ToolAllowlistGroup{ }, { Name: "memory_read", - Description: "Read-only memory tools", + Description: "Read-only memory and session recall tools", Aliases: []string{"mem_read"}, - Tools: []string{"memory_search", "memory_get"}, + Tools: []string{"memory_search", "memory_get", "session_search"}, }, { Name: "memory_write", From 79e0a48b743380a560da01ce6879143610c788b7 Mon Sep 17 00:00:00 2001 From: lpf Date: Tue, 14 Apr 2026 14:53:18 +0800 Subject: [PATCH 5/5] feat(runtime): add process watch patterns, unified backup/import, pluggable context engine, token usage, and codex device login --- cmd/cli_common.go | 1 + cmd/cmd_backup.go | 354 ++++++++++++++++++++++++++ cmd/cmd_backup_test.go | 71 ++++++ cmd/cmd_config.go | 10 +- cmd/main.go | 4 +- pkg/agent/context_engine.go | 55 ++++ pkg/agent/context_engine_test.go | 48 ++++ pkg/agent/loop.go | 192 +++++++++++--- pkg/agent/loop_usage_test.go | 
26 ++ pkg/providers/oauth.go | 169 ++++++++++-- pkg/providers/oauth_test.go | 92 +++++++ pkg/tools/message_process_test.go | 62 +++++ pkg/tools/process_tool.go | 167 +++++++++++- pkg/tools/subagent.go | 6 + pkg/tools/subagent_runtime_context.go | 14 +- workspace/AGENTS.md | 21 +- workspace/SOUL.md | 10 + workspace/USER.md | 19 ++ 18 files changed, 1257 insertions(+), 64 deletions(-) create mode 100644 cmd/cmd_backup.go create mode 100644 cmd/cmd_backup_test.go create mode 100644 pkg/agent/context_engine.go create mode 100644 pkg/agent/context_engine_test.go create mode 100644 pkg/agent/loop_usage_test.go diff --git a/cmd/cli_common.go b/cmd/cli_common.go index b683190..0e108a2 100644 --- a/cmd/cli_common.go +++ b/cmd/cli_common.go @@ -97,6 +97,7 @@ func printHelp() { fmt.Println(" cron Manage scheduled tasks") fmt.Println(" channel Test and manage messaging channels") fmt.Println(" skills Manage skills (install, list, remove)") + fmt.Println(" backup Unified backup/import for config, sessions, memory, skills") if tuiEnabled { fmt.Println(" tui Chat in terminal using the gateway chat API") } diff --git a/cmd/cmd_backup.go b/cmd/cmd_backup.go new file mode 100644 index 0000000..b82da53 --- /dev/null +++ b/cmd/cmd_backup.go @@ -0,0 +1,354 @@ +package main + +import ( + "archive/zip" + "encoding/json" + "fmt" + "io" + "os" + "path/filepath" + "sort" + "strings" + "time" +) + +type unifiedBackupManifest struct { + Version int `json:"version"` + CreatedAt string `json:"created_at"` + Config string `json:"config"` + Workspace string `json:"workspace"` + Includes []string `json:"includes,omitempty"` +} + +func backupCmd() { + if len(os.Args) < 3 { + backupHelp() + return + } + switch strings.TrimSpace(os.Args[2]) { + case "create": + backupCreateCmd() + case "import": + backupImportCmd() + default: + fmt.Printf("Unknown backup command: %s\n", os.Args[2]) + backupHelp() + } +} + +func backupHelp() { + fmt.Println("\nBackup commands:") + fmt.Println(" create [archive.zip] 
Create unified backup (config + sessions + memory + skills)") + fmt.Println(" import <archive.zip> Restore unified backup and auto-create rollback snapshot") + fmt.Println() + fmt.Println("Examples:") + fmt.Println(" clawgo backup create") + fmt.Println(" clawgo backup create /tmp/clawgo-backup.zip") + fmt.Println(" clawgo backup import /tmp/clawgo-backup.zip") +} + +func backupCreateCmd() { + cfg, err := loadConfig() + if err != nil { + fmt.Printf("Error loading config: %v\n", err) + return + } + out := "" + if len(os.Args) >= 4 { + out = strings.TrimSpace(os.Args[3]) + } + if out == "" { + out = defaultBackupPathForConfig("clawgo-backup", getConfigPath()) + } + count, err := createUnifiedBackup(cfg.WorkspacePath(), getConfigPath(), out) + if err != nil { + fmt.Printf("Backup failed: %v\n", err) + return + } + fmt.Printf("Backup created: %s (%d files)\n", out, count) +} + +func backupImportCmd() { + if len(os.Args) < 4 { + fmt.Println("Usage: clawgo backup import <archive.zip>") + return + } + cfg, err := loadConfig() + if err != nil { + fmt.Printf("Error loading config: %v\n", err) + return + } + archive := strings.TrimSpace(os.Args[3]) + rollbackPath, restored, err := importUnifiedBackup(cfg.WorkspacePath(), getConfigPath(), archive) + if err != nil { + fmt.Printf("Import failed: %v\n", err) + return + } + fmt.Printf("Import completed: %d files restored\n", restored) + fmt.Printf("Rollback snapshot: %s\n", rollbackPath) +} + +func defaultBackupPath(prefix string) string { + return defaultBackupPathForConfig(prefix, getConfigPath()) +} + +func defaultBackupPathForConfig(prefix, configPath string) string { + configPath = strings.TrimSpace(configPath) + if configPath == "" { + configPath = getConfigPath() + } + dir := filepath.Join(filepath.Dir(configPath), "backups") + _ = os.MkdirAll(dir, 0755) + name := fmt.Sprintf("%s-%s.zip", prefix, time.Now().Format("20060102-150405")) + return filepath.Join(dir, name) +} + +func createUnifiedBackup(workspacePath, configPath, archivePath string) (int, 
error) { + workspacePath = strings.TrimSpace(workspacePath) + configPath = strings.TrimSpace(configPath) + archivePath = strings.TrimSpace(archivePath) + if workspacePath == "" { + return 0, fmt.Errorf("workspace path is empty") + } + if configPath == "" { + return 0, fmt.Errorf("config path is empty") + } + if archivePath == "" { + return 0, fmt.Errorf("archive path is empty") + } + if err := os.MkdirAll(filepath.Dir(archivePath), 0755); err != nil { + return 0, err + } + + f, err := os.Create(archivePath) + if err != nil { + return 0, err + } + defer f.Close() + zw := zip.NewWriter(f) + defer zw.Close() + + agentsRoot := filepath.Join(filepath.Dir(workspacePath), "agents") + includes := []string{ + "config/config.json", + "workspace/MEMORY.md", + "workspace/memory/**", + "workspace/skills/**", + "agents/**", + "workspace/AGENTS.md", + "workspace/USER.md", + "workspace/SOUL.md", + } + + fileCount := 0 + seen := map[string]struct{}{} + addFile := func(src, dst string) error { + src = filepath.Clean(src) + dst = filepath.ToSlash(strings.TrimSpace(dst)) + if src == "" || dst == "" { + return nil + } + if _, ok := seen[dst]; ok { + return nil + } + info, err := os.Stat(src) + if err != nil { + if os.IsNotExist(err) { + return nil + } + return err + } + if info.IsDir() { + return nil + } + r, err := os.Open(src) + if err != nil { + return err + } + defer r.Close() + hdr, err := zip.FileInfoHeader(info) + if err != nil { + return err + } + hdr.Name = dst + hdr.Method = zip.Deflate + w, err := zw.CreateHeader(hdr) + if err != nil { + return err + } + if _, err := io.Copy(w, r); err != nil { + return err + } + seen[dst] = struct{}{} + fileCount++ + return nil + } + addTree := func(srcDir, dstDir string) error { + info, err := os.Stat(srcDir) + if err != nil { + if os.IsNotExist(err) { + return nil + } + return err + } + if !info.IsDir() { + return addFile(srcDir, filepath.Join(dstDir, filepath.Base(srcDir))) + } + return filepath.Walk(srcDir, func(path string, info 
os.FileInfo, err error) error { + if err != nil { + return err + } + if info.IsDir() { + return nil + } + rel, err := filepath.Rel(srcDir, path) + if err != nil { + return err + } + return addFile(path, filepath.Join(dstDir, rel)) + }) + } + + if err := addFile(configPath, "config/config.json"); err != nil { + return 0, err + } + if err := addFile(filepath.Join(workspacePath, "MEMORY.md"), "workspace/MEMORY.md"); err != nil { + return 0, err + } + if err := addTree(filepath.Join(workspacePath, "memory"), "workspace/memory"); err != nil { + return 0, err + } + if err := addTree(filepath.Join(workspacePath, "skills"), "workspace/skills"); err != nil { + return 0, err + } + if err := addTree(agentsRoot, "agents"); err != nil { + return 0, err + } + for _, name := range []string{"AGENTS.md", "USER.md", "SOUL.md"} { + if err := addFile(filepath.Join(workspacePath, name), filepath.Join("workspace", name)); err != nil { + return 0, err + } + } + + manifest := unifiedBackupManifest{ + Version: 1, + CreatedAt: time.Now().UTC().Format(time.RFC3339), + Config: filepath.Clean(configPath), + Workspace: filepath.Clean(workspacePath), + Includes: includes, + } + manifestData, _ := json.MarshalIndent(manifest, "", " ") + w, err := zw.Create("manifest.json") + if err != nil { + return 0, err + } + if _, err := w.Write(manifestData); err != nil { + return 0, err + } + return fileCount, nil +} + +func importUnifiedBackup(workspacePath, configPath, archivePath string) (string, int, error) { + workspacePath = strings.TrimSpace(workspacePath) + configPath = strings.TrimSpace(configPath) + archivePath = strings.TrimSpace(archivePath) + if workspacePath == "" || configPath == "" || archivePath == "" { + return "", 0, fmt.Errorf("invalid import paths") + } + r, err := zip.OpenReader(archivePath) + if err != nil { + return "", 0, err + } + defer r.Close() + + rollbackPath := defaultBackupPathForConfig("clawgo-rollback", configPath) + if _, err := createUnifiedBackup(workspacePath, 
configPath, rollbackPath); err != nil { + return "", 0, fmt.Errorf("create rollback snapshot: %w", err) + } + + tmpDir, err := os.MkdirTemp("", "clawgo-import-*") + if err != nil { + return "", 0, err + } + defer os.RemoveAll(tmpDir) + + for _, zf := range r.File { + target := filepath.Clean(filepath.Join(tmpDir, zf.Name)) + if !strings.HasPrefix(target, tmpDir+string(filepath.Separator)) && target != tmpDir { + return "", 0, fmt.Errorf("invalid zip entry path: %s", zf.Name) + } + if zf.FileInfo().IsDir() { + if err := os.MkdirAll(target, 0755); err != nil { + return "", 0, err + } + continue + } + if err := os.MkdirAll(filepath.Dir(target), 0755); err != nil { + return "", 0, err + } + rc, err := zf.Open() + if err != nil { + return "", 0, err + } + data, readErr := io.ReadAll(rc) + _ = rc.Close() + if readErr != nil { + return "", 0, readErr + } + if err := os.WriteFile(target, data, zf.Mode()); err != nil { + return "", 0, err + } + } + + agentsRoot := filepath.Join(filepath.Dir(workspacePath), "agents") + restoreTasks := []struct { + src string + dst string + }{ + {src: filepath.Join(tmpDir, "workspace"), dst: workspacePath}, + {src: filepath.Join(tmpDir, "agents"), dst: agentsRoot}, + } + sort.SliceStable(restoreTasks, func(i, j int) bool { return restoreTasks[i].src < restoreTasks[j].src }) + restored := 0 + for _, task := range restoreTasks { + info, err := os.Stat(task.src) + if err != nil { + if os.IsNotExist(err) { + continue + } + return rollbackPath, restored, err + } + if !info.IsDir() { + continue + } + if err := os.MkdirAll(task.dst, 0755); err != nil { + return rollbackPath, restored, err + } + if err := copyDirectory(task.src, task.dst); err != nil { + return rollbackPath, restored, err + } + filepath.Walk(task.src, func(path string, info os.FileInfo, err error) error { + if err == nil && !info.IsDir() { + restored++ + } + return nil + }) + } + + importedConfig := filepath.Join(tmpDir, "config", "config.json") + if info, err := 
os.Stat(importedConfig); err == nil && !info.IsDir() { + data, err := os.ReadFile(importedConfig) + if err != nil { + return rollbackPath, restored, err + } + if err := os.MkdirAll(filepath.Dir(configPath), 0755); err != nil { + return rollbackPath, restored, err + } + if err := os.WriteFile(configPath, data, 0644); err != nil { + return rollbackPath, restored, err + } + restored++ + } + + return rollbackPath, restored, nil +} diff --git a/cmd/cmd_backup_test.go b/cmd/cmd_backup_test.go new file mode 100644 index 0000000..5ea3865 --- /dev/null +++ b/cmd/cmd_backup_test.go @@ -0,0 +1,71 @@ +package main + +import ( + "os" + "path/filepath" + "strings" + "testing" +) + +func TestUnifiedBackupCreateAndImport(t *testing.T) { + t.Parallel() + + root := t.TempDir() + workspace := filepath.Join(root, "workspace") + configPath := filepath.Join(root, "config", "config.json") + agentsDir := filepath.Join(root, "agents", "main", "sessions") + skillsDir := filepath.Join(workspace, "skills", "demo") + memoryDir := filepath.Join(workspace, "memory") + if err := os.MkdirAll(agentsDir, 0755); err != nil { + t.Fatalf("mkdir agents: %v", err) + } + if err := os.MkdirAll(skillsDir, 0755); err != nil { + t.Fatalf("mkdir skills: %v", err) + } + if err := os.MkdirAll(memoryDir, 0755); err != nil { + t.Fatalf("mkdir memory: %v", err) + } + if err := os.MkdirAll(filepath.Dir(configPath), 0755); err != nil { + t.Fatalf("mkdir config: %v", err) + } + _ = os.WriteFile(configPath, []byte(`{"gateway":{"token":"abc"}}`), 0644) + _ = os.WriteFile(filepath.Join(workspace, "MEMORY.md"), []byte("long-term"), 0644) + _ = os.WriteFile(filepath.Join(memoryDir, "2026-04-14.md"), []byte("daily-note"), 0644) + _ = os.WriteFile(filepath.Join(skillsDir, "SKILL.md"), []byte("# demo"), 0644) + _ = os.WriteFile(filepath.Join(agentsDir, "main.active.jsonl"), []byte("{\"type\":\"message\"}\n"), 0644) + + archive := filepath.Join(root, "backup.zip") + files, err := createUnifiedBackup(workspace, configPath, 
archive) + if err != nil { + t.Fatalf("createUnifiedBackup: %v", err) + } + if files < 4 { + t.Fatalf("expected backup files >= 4, got %d", files) + } + + // Mutate files to ensure import actually restores prior state. + _ = os.WriteFile(configPath, []byte(`{"gateway":{"token":"changed"}}`), 0644) + _ = os.WriteFile(filepath.Join(workspace, "MEMORY.md"), []byte("changed-memory"), 0644) + + rollback, restored, err := importUnifiedBackup(workspace, configPath, archive) + if err != nil { + t.Fatalf("importUnifiedBackup: %v", err) + } + if restored < 4 { + t.Fatalf("expected restored files >= 4, got %d", restored) + } + if strings.TrimSpace(rollback) == "" { + t.Fatalf("expected rollback path") + } + if _, err := os.Stat(rollback); err != nil { + t.Fatalf("rollback snapshot missing: %v", err) + } + cfgData, _ := os.ReadFile(configPath) + if !strings.Contains(string(cfgData), `"abc"`) { + t.Fatalf("config not restored, got %s", string(cfgData)) + } + memData, _ := os.ReadFile(filepath.Join(workspace, "MEMORY.md")) + if strings.TrimSpace(string(memData)) != "long-term" { + t.Fatalf("memory not restored, got %s", string(memData)) + } +} diff --git a/cmd/cmd_config.go b/cmd/cmd_config.go index c76fef4..7e561f0 100644 --- a/cmd/cmd_config.go +++ b/cmd/cmd_config.go @@ -406,6 +406,12 @@ func providerLoginCmd() { fmt.Printf("Provider %s is not configured with auth=oauth/hybrid\n", providerName) os.Exit(1) } + oauthProvider := strings.ToLower(strings.TrimSpace(pc.OAuth.Provider)) + if oauthProvider == "codex" { + // Codex login is device-code only; callback/browser modes are no longer used. 
+ manual = false + noBrowser = true + } if manual { noBrowser = true } @@ -460,8 +466,10 @@ func providerLoginCmd() { fmt.Printf("OAuth login succeeded for provider %s\n", providerName) if manual { fmt.Println("Mode: manual callback URL paste") - } else if noBrowser { + } else if noBrowser && oauthProvider != "codex" { fmt.Println("Mode: local callback listener without auto-opening browser") + } else if oauthProvider == "codex" { + fmt.Println("Mode: device-code") } if session.Email != "" { fmt.Printf("Account: %s\n", session.Email) diff --git a/cmd/main.go b/cmd/main.go index 69c85bc..2966979 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -15,7 +15,7 @@ import ( "github.com/YspCoder/clawgo/pkg/logger" ) -var version = "0.0.2" +var version = "1.2.0" var buildTime = "unknown" const logo = ">" @@ -65,6 +65,8 @@ func main() { channelCmd() case "skills": skillsCmd() + case "backup": + backupCmd() case "tui": tuiCmd() case "version", "--version", "-v": diff --git a/pkg/agent/context_engine.go b/pkg/agent/context_engine.go new file mode 100644 index 0000000..cb5b0b3 --- /dev/null +++ b/pkg/agent/context_engine.go @@ -0,0 +1,55 @@ +package agent + +import "github.com/YspCoder/clawgo/pkg/providers" + +// ContextBuildRequest defines inputs for building a provider message window. +type ContextBuildRequest struct { + History []providers.Message + Summary string + CurrentMessage string + Media []string + Channel string + ChatID string + ResponseLanguage string + MemoryNamespace string +} + +// ContextEngine allows swapping context-assembly behavior without touching AgentLoop flow. 
+type ContextEngine interface { + BuildMessages(req ContextBuildRequest) []providers.Message + SkillsInfo() map[string]interface{} +} + +type defaultContextEngine struct { + builder *ContextBuilder +} + +func NewDefaultContextEngine(builder *ContextBuilder) ContextEngine { + if builder == nil { + return nil + } + return &defaultContextEngine{builder: builder} +} + +func (e *defaultContextEngine) BuildMessages(req ContextBuildRequest) []providers.Message { + if e == nil || e.builder == nil { + return nil + } + return e.builder.BuildMessagesWithMemoryNamespace( + req.History, + req.Summary, + req.CurrentMessage, + req.Media, + req.Channel, + req.ChatID, + req.ResponseLanguage, + req.MemoryNamespace, + ) +} + +func (e *defaultContextEngine) SkillsInfo() map[string]interface{} { + if e == nil || e.builder == nil { + return map[string]interface{}{} + } + return e.builder.GetSkillsInfo() +} diff --git a/pkg/agent/context_engine_test.go b/pkg/agent/context_engine_test.go new file mode 100644 index 0000000..7b54dbe --- /dev/null +++ b/pkg/agent/context_engine_test.go @@ -0,0 +1,48 @@ +package agent + +import ( + "testing" + + "github.com/YspCoder/clawgo/pkg/bus" + "github.com/YspCoder/clawgo/pkg/providers" + "github.com/YspCoder/clawgo/pkg/session" +) + +type testContextEngine struct { + lastReq ContextBuildRequest + messages []providers.Message +} + +func (e *testContextEngine) BuildMessages(req ContextBuildRequest) []providers.Message { + e.lastReq = req + return append([]providers.Message(nil), e.messages...) 
+} + +func (e *testContextEngine) SkillsInfo() map[string]interface{} { + return map[string]interface{}{"total": 0} +} + +func TestAgentLoopUsesPluggableContextEngine(t *testing.T) { + t.Parallel() + + engine := &testContextEngine{ + messages: []providers.Message{{Role: "system", Content: "from-test-engine"}}, + } + loop := &AgentLoop{ + sessions: session.NewSessionManager(""), + contextEngine: engine, + } + msg := bus.InboundMessage{ + Channel: "cli", + ChatID: "direct", + SessionKey: "main", + Content: "hello", + } + messages, _ := loop.prepareUserMessageContext(msg, "main") + if len(messages) != 1 || messages[0].Content != "from-test-engine" { + t.Fatalf("expected custom engine output, got %#v", messages) + } + if engine.lastReq.CurrentMessage != "hello" || engine.lastReq.Channel != "cli" { + t.Fatalf("unexpected context request: %#v", engine.lastReq) + } +} diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go index a619c5f..51b8ea7 100644 --- a/pkg/agent/loop.go +++ b/pkg/agent/loop.go @@ -46,6 +46,7 @@ type AgentLoop struct { maxIterations int sessions *session.SessionManager contextBuilder *ContextBuilder + contextEngine ContextEngine tools *tools.ToolRegistry compactionEnabled bool compactionTrigger int @@ -117,6 +118,18 @@ func (al *AgentLoop) SetConfigPath(path string) { al.configPath = strings.TrimSpace(path) } +func (al *AgentLoop) SetContextEngine(engine ContextEngine) { + if al == nil { + return + } + al.runMu.Lock() + defer al.runMu.Unlock() + if engine == nil && al.contextBuilder != nil { + engine = NewDefaultContextEngine(al.contextBuilder) + } + al.contextEngine = engine +} + // StartupCompactionReport provides startup memory/session maintenance stats. 
type StartupCompactionReport struct { TotalSessions int `json:"total_sessions"` @@ -234,6 +247,7 @@ func NewAgentLoop(cfg *config.Config, msgBus *bus.MessageBus, provider providers // Register system info tool toolsRegistry.Register(tools.NewSystemInfoTool()) + contextBuilder := NewContextBuilder(workspace, func() []string { return toolsRegistry.GetSummaries() }) loop := &AgentLoop{ bus: msgBus, cfg: cfg, @@ -244,7 +258,8 @@ func NewAgentLoop(cfg *config.Config, msgBus *bus.MessageBus, provider providers temperature: cfg.Agents.Defaults.Temperature, maxIterations: cfg.Agents.Defaults.MaxToolIterations, sessions: sessionsManager, - contextBuilder: NewContextBuilder(workspace, func() []string { return toolsRegistry.GetSummaries() }), + contextBuilder: contextBuilder, + contextEngine: NewDefaultContextEngine(contextBuilder), tools: toolsRegistry, compactionEnabled: cfg.Agents.Defaults.ContextCompaction.Enabled, compactionTrigger: cfg.Agents.Defaults.ContextCompaction.TriggerMessages, @@ -893,14 +908,17 @@ type llmTurnLoopConfig struct { } type llmTurnLoopResult struct { - messages []providers.Message - pendingPersist []providers.Message - finalContent string - iteration int - attemptCount int - restartCount int - failureCode string - hasToolActivity bool + messages []providers.Message + pendingPersist []providers.Message + finalContent string + iteration int + attemptCount int + restartCount int + promptTokens int + completionTokens int + totalTokens int + failureCode string + hasToolActivity bool } func logLLMTurnRequest(iteration, maxIterations int, providerName, activeModel string, messages []providers.Message, providerToolDefs []providers.ToolDefinition, maxTokens int, temperature float64) { @@ -947,6 +965,63 @@ func logLLMToolCalls(iteration int, toolCalls []providers.ToolCall) { }) } +func mergeUsageTotals(dst *llmTurnLoopResult, usage *providers.UsageInfo) { + if dst == nil || usage == nil { + return + } + prompt := usage.PromptTokens + completion := 
usage.CompletionTokens + total := usage.TotalTokens + if total <= 0 { + total = prompt + completion + } + dst.promptTokens += prompt + dst.completionTokens += completion + dst.totalTokens += total +} + +func estimateResponseUsage(ctx context.Context, provider providers.LLMProvider, model string, prompt []providers.Message, toolDefs []providers.ToolDefinition, response *providers.LLMResponse) *providers.UsageInfo { + if response == nil { + return nil + } + if response.Usage != nil { + return response.Usage + } + counter, ok := provider.(providers.TokenCounter) + if !ok { + return nil + } + usage, err := counter.CountTokens(ctx, prompt, toolDefs, model, nil) + if err != nil || usage == nil { + return nil + } + promptTokens := usage.TotalTokens + if promptTokens <= 0 { + promptTokens = usage.PromptTokens + } + if promptTokens <= 0 { + return nil + } + completionChars := len(strings.TrimSpace(response.Content)) + for _, tc := range response.ToolCalls { + completionChars += len(strings.TrimSpace(tc.Name)) + if tc.Arguments != nil { + if b, err := json.Marshal(tc.Arguments); err == nil { + completionChars += len(b) + } + } + } + completionTokens := completionChars / 4 + if completionTokens <= 0 && completionChars > 0 { + completionTokens = 1 + } + return &providers.UsageInfo{ + PromptTokens: promptTokens, + CompletionTokens: completionTokens, + TotalTokens: promptTokens + completionTokens, + } +} + func buildAssistantToolCallMessage(response *providers.LLMResponse) providers.Message { assistantMsg := providers.Message{ Role: "assistant", @@ -1195,6 +1270,7 @@ func (al *AgentLoop) runLLMTurnLoop(cfg llmTurnLoopConfig) (llmTurnLoopResult, e }) return result, fmt.Errorf("LLM call failed: %w", err) } + mergeUsageTotals(&result, estimateResponseUsage(cfg.ctx, activeProvider, activeModel, result.messages, providerToolDefs, response)) if len(response.ToolCalls) == 0 { result.finalContent = response.Content @@ -1240,16 +1316,30 @@ func (al *AgentLoop) 
prepareUserMessageContext(msg bus.InboundMessage, memoryNam } preferredLang, lastLang := al.sessions.GetLanguagePreferences(msg.SessionKey) responseLang := DetectResponseLanguage(msg.Content, preferredLang, lastLang) - messages := al.contextBuilder.BuildMessagesWithMemoryNamespace( - history, - summary, - msg.Content, - nil, - msg.Channel, - msg.ChatID, - responseLang, - memoryNamespace, - ) + messages := []providers.Message(nil) + if al.contextEngine != nil { + messages = al.contextEngine.BuildMessages(ContextBuildRequest{ + History: history, + Summary: summary, + CurrentMessage: msg.Content, + Channel: msg.Channel, + ChatID: msg.ChatID, + ResponseLanguage: responseLang, + MemoryNamespace: memoryNamespace, + }) + } + if len(messages) == 0 && al.contextBuilder != nil { + messages = al.contextBuilder.BuildMessagesWithMemoryNamespace( + history, + summary, + msg.Content, + nil, + msg.Channel, + msg.ChatID, + responseLang, + memoryNamespace, + ) + } return messages, responseLang } @@ -1270,15 +1360,29 @@ func (al *AgentLoop) prepareSystemMessageContext(sessionKey string, msg bus.Inbo summary := al.sessions.GetSummary(sessionKey) preferredLang, lastLang := al.sessions.GetLanguagePreferences(sessionKey) responseLang := DetectResponseLanguage(msg.Content, preferredLang, lastLang) - messages := al.contextBuilder.BuildMessages( - history, - summary, - msg.Content, - nil, - originChannel, - originChatID, - responseLang, - ) + messages := []providers.Message(nil) + if al.contextEngine != nil { + messages = al.contextEngine.BuildMessages(ContextBuildRequest{ + History: history, + Summary: summary, + CurrentMessage: msg.Content, + Channel: originChannel, + ChatID: originChatID, + ResponseLanguage: responseLang, + MemoryNamespace: "main", + }) + } + if len(messages) == 0 && al.contextBuilder != nil { + messages = al.contextBuilder.BuildMessages( + history, + summary, + msg.Content, + nil, + originChannel, + originChatID, + responseLang, + ) + } return messages, responseLang } 
@@ -1567,18 +1671,24 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage) }) if err != nil { tools.RecordSubagentExecutionStats(ctx, tools.SubagentExecutionStats{ - Iterations: loopResult.iteration, - Attempts: loopResult.attemptCount, - Restarts: loopResult.restartCount, - FailureCode: classifyLLMFailureCode(err), + Iterations: loopResult.iteration, + Attempts: loopResult.attemptCount, + Restarts: loopResult.restartCount, + PromptTokens: loopResult.promptTokens, + CompletionTokens: loopResult.completionTokens, + TotalTokens: loopResult.totalTokens, + FailureCode: classifyLLMFailureCode(err), }) al.reopenSpecTaskOnError(specTaskRef, msg, err) return "", err } tools.RecordSubagentExecutionStats(ctx, tools.SubagentExecutionStats{ - Iterations: loopResult.iteration, - Attempts: loopResult.attemptCount, - Restarts: loopResult.restartCount, + Iterations: loopResult.iteration, + Attempts: loopResult.attemptCount, + Restarts: loopResult.restartCount, + PromptTokens: loopResult.promptTokens, + CompletionTokens: loopResult.completionTokens, + TotalTokens: loopResult.totalTokens, }) finalContent, userContent := al.finalizeUserTurnResponse(ctx, msg, responseLang, loopResult) @@ -1590,8 +1700,14 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage) "sender_id": msg.SenderID, "preview": responsePreview, "iterations": loopResult.iteration, + "attempts": loopResult.attemptCount, "final_length": len(finalContent), "user_length": len(userContent), + "token_usage": map[string]int{ + "prompt": loopResult.promptTokens, + "completion": loopResult.completionTokens, + "total": loopResult.totalTokens, + }, }) al.completeSpecTaskOnSuccess(specTaskRef, msg, userContent) @@ -2547,7 +2663,11 @@ func (al *AgentLoop) GetStartupInfo() map[string]interface{} { } // Skills info - info["skills"] = al.contextBuilder.GetSkillsInfo() + if al.contextEngine != nil { + info["skills"] = al.contextEngine.SkillsInfo() + } else if al.contextBuilder 
!= nil { + info["skills"] = al.contextBuilder.GetSkillsInfo() + } return info } diff --git a/pkg/agent/loop_usage_test.go b/pkg/agent/loop_usage_test.go new file mode 100644 index 0000000..3924471 --- /dev/null +++ b/pkg/agent/loop_usage_test.go @@ -0,0 +1,26 @@ +package agent + +import ( + "testing" + + "github.com/YspCoder/clawgo/pkg/providers" +) + +func TestMergeUsageTotals(t *testing.T) { + t.Parallel() + + var result llmTurnLoopResult + mergeUsageTotals(&result, &providers.UsageInfo{PromptTokens: 10, CompletionTokens: 4, TotalTokens: 0}) + mergeUsageTotals(&result, &providers.UsageInfo{PromptTokens: 3, CompletionTokens: 2, TotalTokens: 9}) + + if result.promptTokens != 13 { + t.Fatalf("prompt tokens = %d, want 13", result.promptTokens) + } + if result.completionTokens != 6 { + t.Fatalf("completion tokens = %d, want 6", result.completionTokens) + } + // First merge falls back to prompt+completion (14), second uses explicit total (9). + if result.totalTokens != 23 { + t.Fatalf("total tokens = %d, want 23", result.totalTokens) + } +} diff --git a/pkg/providers/oauth.go b/pkg/providers/oauth.go index 76d860a..22968bb 100644 --- a/pkg/providers/oauth.go +++ b/pkg/providers/oauth.go @@ -33,11 +33,12 @@ const ( oauthStyleJSON = "json" defaultCodexOAuthProvider = "codex" - defaultCodexAuthURL = "https://auth.openai.com/oauth/authorize" + defaultCodexAuthURL = "https://auth.openai.com/codex/device" + defaultCodexDeviceCodeURL = "https://auth.openai.com/api/accounts/deviceauth/usercode" + defaultCodexDeviceTokenPollURL = "https://auth.openai.com/api/accounts/deviceauth/token" + defaultCodexDeviceRedirectURL = "https://auth.openai.com/deviceauth/callback" defaultCodexTokenURL = "https://auth.openai.com/oauth/token" defaultCodexClientID = "app_EMoamEEZ73f0CkXaXp7hrann" - defaultCodexCallbackPort = 1455 - defaultCodexRedirectPath = "/auth/callback" defaultClaudeOAuthProvider = "claude" defaultClaudeAuthURL = "https://claude.ai/oauth/authorize" defaultClaudeTokenURL = 
"https://api.anthropic.com/v1/oauth/token" @@ -90,14 +91,14 @@ var ( ) var ( - defaultAntigravityClientIDValue = "1071006060591-" + "tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com" + defaultAntigravityClientIDValue = "1071006060591-" + "tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com" defaultAntigravityClientSecretValue = "GOCSPX-" + "K58FWR486LdLJ1mLB8sXC4z6qDAf" - defaultGeminiClientIDValue = "681255809395-" + "oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com" - defaultGeminiClientSecretValue = "GOCSPX-" + "4uHgMPm-1o7Sk-geV6Cu5clXFsxl" - defaultAntigravityClientID = firstNonEmpty(strings.TrimSpace(os.Getenv("CLAWGO_ANTIGRAVITY_CLIENT_ID")), defaultAntigravityClientIDValue) - defaultAntigravityClientSecret = firstNonEmpty(strings.TrimSpace(os.Getenv("CLAWGO_ANTIGRAVITY_CLIENT_SECRET")), defaultAntigravityClientSecretValue) - defaultGeminiClientID = firstNonEmpty(strings.TrimSpace(os.Getenv("CLAWGO_GEMINI_CLIENT_ID")), defaultGeminiClientIDValue) - defaultGeminiClientSecret = firstNonEmpty(strings.TrimSpace(os.Getenv("CLAWGO_GEMINI_CLIENT_SECRET")), defaultGeminiClientSecretValue) + defaultGeminiClientIDValue = "681255809395-" + "oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com" + defaultGeminiClientSecretValue = "GOCSPX-" + "4uHgMPm-1o7Sk-geV6Cu5clXFsxl" + defaultAntigravityClientID = firstNonEmpty(strings.TrimSpace(os.Getenv("CLAWGO_ANTIGRAVITY_CLIENT_ID")), defaultAntigravityClientIDValue) + defaultAntigravityClientSecret = firstNonEmpty(strings.TrimSpace(os.Getenv("CLAWGO_ANTIGRAVITY_CLIENT_SECRET")), defaultAntigravityClientSecretValue) + defaultGeminiClientID = firstNonEmpty(strings.TrimSpace(os.Getenv("CLAWGO_GEMINI_CLIENT_ID")), defaultGeminiClientIDValue) + defaultGeminiClientSecret = firstNonEmpty(strings.TrimSpace(os.Getenv("CLAWGO_GEMINI_CLIENT_SECRET")), defaultGeminiClientSecretValue) ) var ( @@ -152,6 +153,7 @@ type oauthConfig struct { AuthURL string TokenURL string DeviceCodeURL string + 
DeviceTokenURL string UserInfoURL string RedirectURL string RedirectPath string @@ -611,11 +613,20 @@ func resolveOAuthConfig(pc config.ProviderConfig) (oauthConfig, error) { } switch provider { case defaultCodexOAuthProvider: - cfg.CallbackPort = defaultInt(cfg.CallbackPort, defaultCodexCallbackPort) + cfg.FlowKind = oauthFlowDevice cfg.ClientID = firstNonEmpty(cfg.ClientID, defaultCodexClientID) - cfg.AuthURL = firstNonEmpty(cfg.AuthURL, defaultCodexAuthURL) + cfg.AuthURL = firstNonEmpty(defaultCodexAuthURL) cfg.TokenURL = firstNonEmpty(cfg.TokenURL, defaultCodexTokenURL) - cfg.RedirectPath = defaultCodexRedirectPath + deviceURL := strings.TrimSpace(pc.OAuth.AuthURL) + if deviceURL == "" || strings.Contains(strings.ToLower(deviceURL), "/oauth/authorize") { + deviceURL = defaultCodexDeviceCodeURL + } + cfg.DeviceCodeURL = deviceURL + if strings.Contains(strings.ToLower(deviceURL), "/usercode") { + cfg.DeviceTokenURL = strings.Replace(deviceURL, "/usercode", "/token", 1) + } + cfg.DeviceTokenURL = firstNonEmpty(cfg.DeviceTokenURL, defaultCodexDeviceTokenPollURL) + cfg.RedirectURL = firstNonEmpty(strings.TrimSpace(pc.OAuth.RedirectURL), defaultCodexDeviceRedirectURL) if len(cfg.Scopes) == 0 { cfg.Scopes = append([]string(nil), defaultCodexScopes...) 
} @@ -743,11 +754,15 @@ func (m *oauthManager) login(ctx context.Context, apiBase string, opts OAuthLogi if err != nil { return nil, nil, err } - fmt.Printf("Open this URL to continue OAuth login:\n%s\n", flow.AuthURL) + if m.cfg.Provider == defaultCodexOAuthProvider { + fmt.Printf("To continue Codex login:\n1) Open: %s\n2) Enter code: %s\n", flow.AuthURL, strings.TrimSpace(flow.UserCode)) + } else { + fmt.Printf("Open this URL to continue OAuth login:\n%s\n", flow.AuthURL) + } if strings.TrimSpace(flow.UserCode) != "" { fmt.Printf("User code: %s\n", flow.UserCode) } - if !opts.NoBrowser { + if !opts.NoBrowser && m.cfg.Provider != defaultCodexOAuthProvider { if err := openBrowser(flow.AuthURL); err != nil { fmt.Printf("Automatic browser open failed: %v\n", err) } @@ -2379,6 +2394,32 @@ func (m *oauthManager) startDeviceFlow(ctx context.Context, opts OAuthLoginOptio form := url.Values{} form.Set("client_id", m.cfg.ClientID) switch m.cfg.Provider { + case defaultCodexOAuthProvider: + body, _ := json.Marshal(map[string]any{"client_id": m.cfg.ClientID}) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, m.cfg.DeviceCodeURL, strings.NewReader(string(body))) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json") + raw, err := m.doJSONRequest(req, "oauth device request", opts.NetworkProxy) + if err != nil { + return nil, err + } + userCode := strings.TrimSpace(asString(raw["user_code"])) + deviceAuthID := strings.TrimSpace(asString(raw["device_auth_id"])) + if userCode == "" || deviceAuthID == "" { + return nil, fmt.Errorf("oauth device flow missing user_code/device_auth_id") + } + return &OAuthPendingFlow{ + Mode: oauthFlowDevice, + AuthURL: firstNonEmpty(strings.TrimSpace(m.cfg.AuthURL), defaultCodexAuthURL), + UserCode: userCode, + DeviceCode: deviceAuthID, + IntervalSec: defaultInt(asInt(raw["interval"]), 5), + ExpiresAt: deviceExpiry(asInt(raw["expires_in"])), + 
Instructions: "Open the verification URL, enter the user code, and approve the device login.", + }, nil case defaultQwenOAuthProvider: verifier, challenge, err := generatePKCE() if err != nil { @@ -2430,6 +2471,26 @@ func (m *oauthManager) startDeviceFlow(ctx context.Context, opts OAuthLoginOptio } } +func asInt(v any) int { + switch n := v.(type) { + case int: + return n + case int64: + return int(n) + case float64: + return int(n) + case json.Number: + if parsed, err := n.Int64(); err == nil { + return int(parsed) + } + case string: + if parsed, err := strconv.Atoi(strings.TrimSpace(n)); err == nil { + return parsed + } + } + return 0 +} + func (m *oauthManager) doFormDeviceRequest(ctx context.Context, endpoint string, form url.Values, proxyURL string) (map[string]any, error) { req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, strings.NewReader(form.Encode())) if err != nil { @@ -2513,6 +2574,9 @@ func (m *oauthManager) pollDeviceToken(ctx context.Context, flow *OAuthPendingFl if flow == nil || strings.TrimSpace(flow.DeviceCode) == "" { return nil, fmt.Errorf("oauth device flow missing device code") } + if m.cfg.Provider == defaultCodexOAuthProvider { + return m.pollCodexDeviceToken(ctx, flow, proxyURL) + } interval := time.Duration(defaultInt(flow.IntervalSec, 5)) * time.Second deadline := time.Now().Add(m.cfg.DevicePollMax) if expireAt, err := time.Parse(time.RFC3339, strings.TrimSpace(flow.ExpiresAt)); err == nil && expireAt.Before(deadline) { @@ -2560,3 +2624,78 @@ func (m *oauthManager) pollDeviceToken(ctx context.Context, flow *OAuthPendingFl } } } + +func (m *oauthManager) pollCodexDeviceToken(ctx context.Context, flow *OAuthPendingFlow, proxyURL string) (*oauthSession, error) { + interval := time.Duration(defaultInt(flow.IntervalSec, 5)) * time.Second + if interval < 3*time.Second { + interval = 3 * time.Second + } + deadline := time.Now().Add(m.cfg.DevicePollMax) + if expireAt, err := time.Parse(time.RFC3339, 
strings.TrimSpace(flow.ExpiresAt)); err == nil && expireAt.Before(deadline) { + deadline = expireAt + } + for { + if time.Now().After(deadline) { + return nil, fmt.Errorf("oauth device flow timed out") + } + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-time.After(interval): + } + payload, _ := json.Marshal(map[string]any{ + "device_auth_id": strings.TrimSpace(flow.DeviceCode), + "user_code": strings.TrimSpace(flow.UserCode), + }) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, m.cfg.DeviceTokenURL, strings.NewReader(string(payload))) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json") + client, err := m.httpClientForProxy(proxyURL) + if err != nil { + return nil, err + } + resp, err := client.Do(req) + if err != nil { + return nil, fmt.Errorf("oauth device poll request failed: %w", err) + } + body, readErr := io.ReadAll(resp.Body) + resp.Body.Close() + if readErr != nil { + return nil, fmt.Errorf("oauth device poll read failed: %w", readErr) + } + switch resp.StatusCode { + case http.StatusOK: + var raw map[string]any + if err := json.Unmarshal(body, &raw); err != nil { + return nil, fmt.Errorf("oauth device poll decode failed: %w", err) + } + authCode := strings.TrimSpace(asString(raw["authorization_code"])) + verifier := strings.TrimSpace(asString(raw["code_verifier"])) + if authCode == "" || verifier == "" { + return nil, fmt.Errorf("oauth device auth response missing authorization_code/code_verifier") + } + form := url.Values{} + form.Set("grant_type", "authorization_code") + form.Set("code", authCode) + form.Set("redirect_uri", m.cfg.RedirectURL) + form.Set("client_id", m.cfg.ClientID) + form.Set("code_verifier", verifier) + tokenRaw, err := m.doFormTokenRequest(ctx, form, proxyURL) + if err != nil { + return nil, err + } + session, err := sessionFromTokenPayload(m.cfg.Provider, tokenRaw) + if err != nil { + return nil, err + } + return 
m.enrichSession(ctx, session) + case http.StatusForbidden, http.StatusNotFound: + continue + default: + return nil, fmt.Errorf("oauth device poll failed: status=%d body=%s", resp.StatusCode, strings.TrimSpace(string(body))) + } + } +} diff --git a/pkg/providers/oauth_test.go b/pkg/providers/oauth_test.go index 9ffc332..b17ed01 100644 --- a/pkg/providers/oauth_test.go +++ b/pkg/providers/oauth_test.go @@ -323,6 +323,7 @@ func TestResolveOAuthConfigSupportsAdditionalProviders(t *testing.T) { want string flow string }{ + {name: "codex", provider: "codex", want: "codex", flow: oauthFlowDevice}, {name: "anthropic-alias", provider: "anthropic", want: "claude", flow: oauthFlowCallback}, {name: "antigravity", provider: "antigravity", want: "antigravity", flow: oauthFlowCallback}, {name: "gemini", provider: "gemini", want: "gemini", flow: oauthFlowCallback}, @@ -951,6 +952,97 @@ func TestOAuthDeviceFlowQwenManualCompletes(t *testing.T) { } } +func TestOAuthDeviceFlowCodexStartAndComplete(t *testing.T) { + t.Parallel() + + var pollAttempts int32 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/api/accounts/deviceauth/usercode": + if got := r.Header.Get("Content-Type"); !strings.Contains(strings.ToLower(got), "application/json") { + t.Fatalf("expected json content type, got %s", got) + } + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"user_code":"U-CODE","device_auth_id":"dev-auth-1","interval":0,"expires_in":60}`)) + case "/api/accounts/deviceauth/token": + attempt := atomic.AddInt32(&pollAttempts, 1) + if attempt == 1 { + w.WriteHeader(http.StatusNotFound) + _, _ = w.Write([]byte(`{}`)) + return + } + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"authorization_code":"auth-code-1","code_verifier":"verifier-1"}`)) + case "/oauth/token": + if err := r.ParseForm(); err != nil { + t.Fatalf("parse form failed: %v", err) + } + if got := 
r.Form.Get("grant_type"); got != "authorization_code" { + t.Fatalf("unexpected grant_type: %s", got) + } + if got := r.Form.Get("code"); got != "auth-code-1" { + t.Fatalf("unexpected code: %s", got) + } + if got := r.Form.Get("code_verifier"); got != "verifier-1" { + t.Fatalf("unexpected code_verifier: %s", got) + } + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"access_token":"codex-at","refresh_token":"codex-rt","expires_in":3600}`)) + case "/models": + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"data":[{"id":"gpt-5.4"}]}`)) + default: + http.NotFound(w, r) + } + })) + defer server.Close() + + dir := t.TempDir() + manager, err := newOAuthManager(config.ProviderConfig{ + APIBase: server.URL, + Auth: "oauth", + OAuth: config.ProviderOAuthConfig{ + Provider: "codex", + AuthURL: server.URL + "/api/accounts/deviceauth/usercode", + TokenURL: server.URL + "/oauth/token", + RedirectURL: server.URL + "/deviceauth/callback", + CredentialFile: filepath.Join(dir, "codex.json"), + }, + }, 5*time.Second) + if err != nil { + t.Fatalf("new oauth manager failed: %v", err) + } + defer manager.bgCancel() + + flow, err := manager.startDeviceFlow(context.Background(), OAuthLoginOptions{}) + if err != nil { + t.Fatalf("start device flow failed: %v", err) + } + if flow.Mode != oauthFlowDevice { + t.Fatalf("unexpected flow mode: %s", flow.Mode) + } + if flow.UserCode != "U-CODE" || flow.DeviceCode != "dev-auth-1" { + t.Fatalf("unexpected flow payload: %#v", flow) + } + if flow.IntervalSec < 1 { + t.Fatalf("expected normalized poll interval >=1, got %d", flow.IntervalSec) + } + + session, models, err := manager.completeDeviceFlow(context.Background(), server.URL, flow, OAuthLoginOptions{}) + if err != nil { + t.Fatalf("complete device flow failed: %v", err) + } + if session.AccessToken != "codex-at" || session.RefreshToken != "codex-rt" { + t.Fatalf("unexpected session tokens: %#v", session) + } + if 
atomic.LoadInt32(&pollAttempts) < 2 { + t.Fatalf("expected polling retries, got %d", pollAttempts) + } + if len(models) != 1 || models[0] != "gpt-5.4" { + t.Fatalf("unexpected models: %#v", models) + } +} + func TestHTTPProviderHybridFallsBackFromAPIKeyToOAuth(t *testing.T) { t.Parallel() diff --git a/pkg/tools/message_process_test.go b/pkg/tools/message_process_test.go index 6a1a0f0..8d5a38d 100644 --- a/pkg/tools/message_process_test.go +++ b/pkg/tools/message_process_test.go @@ -2,8 +2,10 @@ package tools import ( "context" + "encoding/json" "strings" "testing" + "time" "github.com/YspCoder/clawgo/pkg/bus" ) @@ -66,3 +68,63 @@ func TestProcessToolParsesStringIntegers(t *testing.T) { t.Fatalf("expected json list output, got %s", out) } } + +func TestProcessToolWatchPatternsMatchesLog(t *testing.T) { + t.Parallel() + + pm := NewProcessManager(t.TempDir()) + id, err := pm.Start(context.Background(), "printf 'READY\\n'; sleep 0.05", "") + if err != nil { + t.Fatalf("start failed: %v", err) + } + tool := NewProcessTool(pm) + + out, err := tool.Execute(context.Background(), map[string]interface{}{ + "action": "watch_patterns", + "session_id": id, + "patterns": []interface{}{"ready"}, + "timeout_ms": 2000, + "interval_ms": 50, + }) + if err != nil { + t.Fatalf("execute failed: %v", err) + } + var payload map[string]interface{} + if err := json.Unmarshal([]byte(out), &payload); err != nil { + t.Fatalf("invalid json output: %v (%s)", err, out) + } + if matched, _ := payload["matched"].(bool); !matched { + t.Fatalf("expected matched response, got %v", payload) + } +} + +func TestProcessToolWatchPatternsTimesOut(t *testing.T) { + t.Parallel() + + pm := NewProcessManager(t.TempDir()) + id, err := pm.Start(context.Background(), "sleep 0.3", "") + if err != nil { + t.Fatalf("start failed: %v", err) + } + tool := NewProcessTool(pm) + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + + out, err := tool.Execute(ctx, map[string]interface{}{ 
+ "action": "watch_patterns", + "session_id": id, + "patterns": "nomatch", + "timeout_ms": "120", + "interval_ms": "30", + }) + if err != nil { + t.Fatalf("execute failed: %v", err) + } + var payload map[string]interface{} + if err := json.Unmarshal([]byte(out), &payload); err != nil { + t.Fatalf("invalid json output: %v (%s)", err, out) + } + if timedOut, _ := payload["timed_out"].(bool); !timedOut { + t.Fatalf("expected timed_out=true, got %v", payload) + } +} diff --git a/pkg/tools/process_tool.go b/pkg/tools/process_tool.go index 59b62b8..92914ef 100644 --- a/pkg/tools/process_tool.go +++ b/pkg/tools/process_tool.go @@ -3,6 +3,8 @@ package tools import ( "context" "encoding/json" + "fmt" + "strings" "time" ) @@ -11,15 +13,19 @@ type ProcessTool struct{ m *ProcessManager } func NewProcessTool(m *ProcessManager) *ProcessTool { return &ProcessTool{m: m} } func (t *ProcessTool) Name() string { return "process" } func (t *ProcessTool) Description() string { - return "Manage background exec sessions: list, poll, log, kill" + return "Manage background exec sessions: list, poll, log, kill, watch_patterns" } func (t *ProcessTool) Parameters() map[string]interface{} { return map[string]interface{}{"type": "object", "properties": map[string]interface{}{ - "action": map[string]interface{}{"type": "string", "description": "list|poll|log|kill"}, - "session_id": map[string]interface{}{"type": "string"}, - "offset": map[string]interface{}{"type": "integer"}, - "limit": map[string]interface{}{"type": "integer"}, - "timeout_ms": map[string]interface{}{"type": "integer"}, + "action": map[string]interface{}{"type": "string", "description": "list|poll|log|kill|watch_patterns"}, + "session_id": map[string]interface{}{"type": "string"}, + "offset": map[string]interface{}{"type": "integer"}, + "limit": map[string]interface{}{"type": "integer"}, + "timeout_ms": map[string]interface{}{"type": "integer"}, + "interval_ms": map[string]interface{}{"type": "integer"}, + "patterns": 
map[string]interface{}{"type": "array", "items": map[string]interface{}{"type": "string"}}, + "case_sensitive": map[string]interface{}{"type": "boolean"}, + "alert_on_exit": map[string]interface{}{"type": "boolean"}, }, "required": []string{"action"}} } @@ -76,7 +82,156 @@ func (t *ProcessTool) Execute(ctx context.Context, args map[string]interface{}) } b, _ := json.Marshal(resp) return string(b), nil + case "watch_patterns": + patterns := MapStringListArg(args, "patterns") + if len(patterns) == 0 { + return "", fmt.Errorf("patterns is required") + } + timeout := MapIntArg(args, "timeout_ms", 30000) + if timeout < 1 { + timeout = 30000 + } + interval := MapIntArg(args, "interval_ms", 250) + if interval < 50 { + interval = 50 + } + if interval > timeout { + interval = timeout + } + off := MapIntArg(args, "offset", 0) + if off < 0 { + off = 0 + } + caseSensitive := false + if v, ok := MapBoolArg(args, "case_sensitive"); ok { + caseSensitive = v + } + alertOnExit := true + if v, ok := MapBoolArg(args, "alert_on_exit"); ok { + alertOnExit = v + } + return t.watchPatterns(ctx, sid, patterns, off, timeout, interval, caseSensitive, alertOnExit) default: return "", nil } } + +func (t *ProcessTool) watchPatterns(ctx context.Context, sid string, patterns []string, offset, timeoutMs, intervalMs int, caseSensitive, alertOnExit bool) (string, error) { + s, ok := t.m.Get(sid) + if !ok { + return "", fmt.Errorf("session not found: %s", sid) + } + type watchPattern struct { + original string + lookup string + } + normalized := make([]watchPattern, 0, len(patterns)) + for _, p := range patterns { + p = strings.TrimSpace(p) + if p == "" { + continue + } + lookup := p + if !caseSensitive { + lookup = strings.ToLower(p) + } + normalized = append(normalized, watchPattern{original: p, lookup: lookup}) + } + if len(normalized) == 0 { + return "", fmt.Errorf("patterns is required") + } + started := time.Now() + deadline := started.Add(time.Duration(timeoutMs) * time.Millisecond) + scanBuf 
:= "" + nextOffset := offset + for { + chunk, err := t.m.Log(sid, nextOffset, 16*1024) + if err != nil { + return "", err + } + if chunk != "" { + nextOffset += len(chunk) + scanBuf += chunk + if len(scanBuf) > 24*1024 { + scanBuf = scanBuf[len(scanBuf)-24*1024:] + } + haystack := scanBuf + if !caseSensitive { + haystack = strings.ToLower(haystack) + } + for _, pattern := range normalized { + if strings.Contains(haystack, pattern.lookup) { + resp := map[string]interface{}{ + "id": s.ID, + "matched": true, + "pattern": pattern.original, + "running": processSessionRunning(s), + "next_offset": nextOffset, + "elapsed_ms": time.Since(started).Milliseconds(), + } + b, _ := json.Marshal(resp) + return string(b), nil + } + } + } + running, exitCode := processSessionState(s) + if !running { + resp := map[string]interface{}{ + "id": s.ID, + "matched": false, + "running": false, + "exit_code": exitCode, + "next_offset": nextOffset, + "elapsed_ms": time.Since(started).Milliseconds(), + } + if alertOnExit { + resp["event"] = "process_exited" + } + b, _ := json.Marshal(resp) + return string(b), nil + } + now := time.Now() + if now.After(deadline) { + resp := map[string]interface{}{ + "id": s.ID, + "matched": false, + "running": true, + "timed_out": true, + "next_offset": nextOffset, + "elapsed_ms": now.Sub(started).Milliseconds(), + } + b, _ := json.Marshal(resp) + return string(b), nil + } + wait := time.Duration(intervalMs) * time.Millisecond + if remaining := time.Until(deadline); wait > remaining { + wait = remaining + } + select { + case <-ctx.Done(): + return "", ctx.Err() + case <-time.After(wait): + } + } +} + +func processSessionRunning(s *processSession) bool { + if s == nil { + return false + } + s.mu.RLock() + defer s.mu.RUnlock() + return s.ExitCode == nil +} + +func processSessionState(s *processSession) (running bool, exitCode interface{}) { + if s == nil { + return false, nil + } + s.mu.RLock() + defer s.mu.RUnlock() + if s.ExitCode == nil { + return true, nil + 
} + return false, *s.ExitCode +} diff --git a/pkg/tools/subagent.go b/pkg/tools/subagent.go index 717d97a..1fc6057 100644 --- a/pkg/tools/subagent.go +++ b/pkg/tools/subagent.go @@ -37,6 +37,9 @@ type SubagentRun struct { IterationCount int `json:"iteration_count,omitempty"` AttemptCount int `json:"attempt_count,omitempty"` RestartCount int `json:"restart_count,omitempty"` + PromptTokens int `json:"prompt_tokens,omitempty"` + CompletionTokens int `json:"completion_tokens,omitempty"` + TotalTokens int `json:"total_tokens,omitempty"` LastFailureCode string `json:"last_failure_code,omitempty"` ThreadID string `json:"thread_id,omitempty"` CorrelationID string `json:"correlation_id,omitempty"` @@ -872,6 +875,9 @@ func (sm *SubagentManager) applyExecutionStats(run *SubagentRun, stats *Subagent run.IterationCount += stats.Iterations run.AttemptCount += stats.Attempts run.RestartCount += stats.Restarts + run.PromptTokens += stats.PromptTokens + run.CompletionTokens += stats.CompletionTokens + run.TotalTokens += stats.TotalTokens if strings.TrimSpace(stats.FailureCode) != "" { run.LastFailureCode = strings.TrimSpace(stats.FailureCode) } diff --git a/pkg/tools/subagent_runtime_context.go b/pkg/tools/subagent_runtime_context.go index 08b795b..bc6341d 100644 --- a/pkg/tools/subagent_runtime_context.go +++ b/pkg/tools/subagent_runtime_context.go @@ -3,10 +3,13 @@ package tools import "context" type SubagentExecutionStats struct { - Iterations int - Attempts int - Restarts int - FailureCode string + Iterations int + Attempts int + Restarts int + PromptTokens int + CompletionTokens int + TotalTokens int + FailureCode string } type subagentExecutionStatsKey struct{} @@ -30,6 +33,9 @@ func RecordSubagentExecutionStats(ctx context.Context, delta SubagentExecutionSt stats.Iterations += delta.Iterations stats.Attempts += delta.Attempts stats.Restarts += delta.Restarts + stats.PromptTokens += delta.PromptTokens + stats.CompletionTokens += delta.CompletionTokens + stats.TotalTokens += 
delta.TotalTokens if delta.FailureCode != "" { stats.FailureCode = delta.FailureCode } diff --git a/workspace/AGENTS.md b/workspace/AGENTS.md index 2ea7903..c87b10a 100644 --- a/workspace/AGENTS.md +++ b/workspace/AGENTS.md @@ -30,6 +30,10 @@ At the start of work, load context in this order: - Daily log: write to `memory/YYYY-MM-DD.md` - Long-term memory: write to `MEMORY.md` - Prefer short, structured notes (bullets) over long paragraphs. +- For "previous chat / last time / earlier discussion" requests: + - first use `session_search` to recover transcript evidence + - then use `memory_search` for durable preferences/decisions + - do not guess from memory when searchable history exists --- @@ -169,7 +173,7 @@ If thinking is complete but output should be suppressed, output exactly: If the user message contains any of: - `remember, 记得, 上次, 之前, 偏好, preference, todo, 待办, 决定, decision` Then: -- prioritize recalling from `MEMORY.md` and today’s log +- prioritize recalling via `session_search`, then `MEMORY.md` and today’s log - if writing memory, write short, structured bullets #### 12.4 Empty listing fallbacks @@ -199,3 +203,18 @@ If content includes any of: ### 13) Safety - No destructive actions without confirmation. - No external sending/actions unless explicitly allowed. +- For channel-facing actions (Telegram/Weixin/Feishu/etc), prefer "internal draft -> explicit send" when ambiguity exists. +- If a tool call may touch external systems, state: target, expected side effect, and rollback hint. 
+ +--- + +### 14) Runtime Reliability Defaults +- Keep user-facing latency first: + - do not block final user response on non-critical background maintenance + - allow best-effort background retries for compaction/index maintenance +- Prefer structured failure reporting: + - classify failures (`timeout`, `stream_failed`, `retry_limit`, `context_compacted`) when available + - avoid generic "failed" messages without actionable context +- Use incremental state paths by default: + - append-only logs first + - sidecar/index as rebuildable acceleration, not source of truth diff --git a/workspace/SOUL.md b/workspace/SOUL.md index 792306a..0fd2b0f 100644 --- a/workspace/SOUL.md +++ b/workspace/SOUL.md @@ -14,12 +14,22 @@ _You're not a chatbot. You're becoming someone._ **Remember you're a guest.** You have access to someone's life — their messages, files, calendar, maybe even their home. That's intimacy. Treat it with respect. +**Prefer evidence over confidence.** When stating "done" or "safe," back it with checks (tests, logs, or verifiable state). + ## Boundaries - Private things stay private. Period. - When in doubt, ask before acting externally. - Never send half-baked replies to messaging surfaces. - You're not the user's voice — be careful in group chats. +- Never claim external side effects completed unless observed or confirmed. + +## Execution Discipline + +- Resolve ambiguity through local inspection before asking. +- Prioritize user-visible latency; move maintenance work to safe background paths when possible. +- Keep failure modes explicit: what failed, why, and what next. +- Use small, reversible changes over broad speculative rewrites. ## Vibe diff --git a/workspace/USER.md b/workspace/USER.md index 5bb7a0f..e4a3f6c 100644 --- a/workspace/USER.md +++ b/workspace/USER.md @@ -7,11 +7,30 @@ _Learn about the person you're helping. 
Update this as you go._ - **Pronouns:** _(optional)_ - **Timezone:** - **Notes:** +- **Primary language(s):** +- **Decision style:** _(fast default / wants options / risk-sensitive)_ +- **Tooling preference:** _(CLI-first / WebUI-first / mixed)_ +- **Release preference:** _(canary-first / direct-prod with safeguards)_ +- **Communication preference:** _(short status / detailed rationale)_ ## Context _(What do they care about? What projects are they working on? What annoys them? What makes them laugh? Build this over time.)_ +## Working Agreements + +- Record only information that improves future execution quality. +- Prefer verifiable facts over interpretations. +- When uncertain, mark as `hypothesis` and avoid persisting as hard preference. + +## Update Protocol + +- When adding a new stable preference, include: + - `source` (where it came from) + - `confidence` (`low`/`medium`/`high`) + - `last_verified` (YYYY-MM-DD) +- Remove stale preferences that have been contradicted by recent behavior. + --- The more you know, the better you can help. But remember — you're learning about a person, not building a dossier. Respect the difference.