From a9169c66ff809657b216bef4e711a345395abc20 Mon Sep 17 00:00:00 2001 From: lpf Date: Wed, 8 Apr 2026 15:25:28 +0800 Subject: [PATCH] Release v1.0.2 --- cmd/cmd_gateway.go | 1 + config.example.json | 12 + pkg/agent/loop.go | 1024 ++++++++++++++++---------- pkg/agent/loop_codex_options_test.go | 215 +++++- pkg/api/server.go | 533 +++++++++++++- pkg/api/server_test.go | 168 ++++- pkg/channels/weixin.go | 16 + pkg/config/config.go | 2 + pkg/config/normalized.go | 10 + pkg/config/normalized_test.go | 13 + pkg/config/validate.go | 3 + pkg/providers/codex_provider.go | 4 +- pkg/providers/codex_provider_test.go | 21 + pkg/providers/http_provider.go | 60 +- pkg/providers/oauth_test.go | 38 + 15 files changed, 1670 insertions(+), 450 deletions(-) diff --git a/cmd/cmd_gateway.go b/cmd/cmd_gateway.go index 5afda3a..d2adee4 100644 --- a/cmd/cmd_gateway.go +++ b/cmd/cmd_gateway.go @@ -275,6 +275,7 @@ func gatewayCmd() { registryServer.SetConfigAfterHook(func(forceRuntimeReload bool) error { return triggerReload("api", forceRuntimeReload) }) + registryServer.SetMessageBus(msgBus) if rawWeixin, ok := channelManager.GetChannel("weixin"); ok { if weixinChannel, ok := rawWeixin.(*channels.WeixinChannel); ok { weixinChannel.SetConfigPath(getConfigPath()) diff --git a/config.example.json b/config.example.json index 018b750..4eda1b1 100644 --- a/config.example.json +++ b/config.example.json @@ -178,6 +178,8 @@ "codex": { "api_base": "https://api.openai.com/v1", "models": ["gpt-5.4"], + "max_tokens": 8192, + "temperature": 0.7, "responses": { "web_search_enabled": false, "web_search_context_size": "", @@ -204,6 +206,8 @@ "gemini": { "api_base": "https://generativelanguage.googleapis.com/v1beta/openai", "models": ["gemini-2.5-pro"], + "max_tokens": 8192, + "temperature": 0.7, "responses": { "web_search_enabled": false, "web_search_context_size": "", @@ -231,6 +235,8 @@ "api_key": "sk-your-openai-api-key", "api_base": "https://api.openai.com/v1", "models": ["gpt-5.4", "gpt-5.4-mini"], + "max_tokens": 8192, + "temperature": 0.7, "responses": { "web_search_enabled": false, "web_search_context_size": "", @@ -245,6 +251,8 @@ "anthropic": { "api_base": "https://api.anthropic.com", "models": ["claude-sonnet-4-20250514"], + "max_tokens": 8192, + "temperature": 0.7, "responses": { "web_search_enabled": false, "web_search_context_size": "", @@ -270,6 +278,8 @@ "qwen": { "api_base": "https://dashscope.aliyuncs.com/compatible-mode/v1", "models": ["qwen-max"], + "max_tokens": 8192, + "temperature": 0.7, "responses": { "web_search_enabled": false, "web_search_context_size": "", @@ -294,6 +304,8 @@ "kimi": { "api_base": "https://api.moonshot.cn/v1", "models": ["kimi-k2-0711-preview"], + "max_tokens": 8192, + "temperature": 0.7, "responses": { "web_search_enabled": false, "web_search_context_size": "", diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go index f8113af..7c12080 100644 --- a/pkg/agent/loop.go +++ b/pkg/agent/loop.go @@ -40,6 +40,8 @@ type AgentLoop struct { provider providers.LLMProvider workspace string model string + maxTokens int + temperature float64 maxIterations int sessions *session.SessionManager contextBuilder *ContextBuilder @@ -56,6 +58,8 @@ type AgentLoop struct { providerNames []string providerPool map[string]providers.LLMProvider providerResponses map[string]config.ProviderResponsesConfig + providerMaxTokens map[string]int + providerTemperatures map[string]float64 telegramStreaming bool providerMu sync.RWMutex sessionProvider map[string]string @@ -222,6 +226,8 @@ func NewAgentLoop(cfg *config.Config, msgBus *bus.MessageBus, provider providers provider: provider, workspace: workspace, model: provider.GetDefaultModel(), + maxTokens: cfg.Agents.Defaults.MaxTokens, + temperature: cfg.Agents.Defaults.Temperature, maxIterations: cfg.Agents.Defaults.MaxToolIterations, sessions: sessionsManager, contextBuilder: NewContextBuilder(workspace, func() []string { return toolsRegistry.GetSummaries() }), @@ -237,6 +243,8 @@ func NewAgentLoop(cfg *config.Config, msgBus *bus.MessageBus, provider providers sessionProvider: map[string]string{}, sessionStreamed: map[string]bool{}, providerResponses: map[string]config.ProviderResponsesConfig{}, + providerMaxTokens: map[string]int{}, + providerTemperatures: map[string]float64{}, telegramStreaming: cfg.Channels.Telegram.Streaming, subagentManager: subagentManager, subagentRouter: subagentRouter, @@ -265,6 +273,8 @@ func NewAgentLoop(cfg *config.Config, msgBus *bus.MessageBus, provider providers loop.providerNames = append(loop.providerNames, primaryName) if pc, ok := config.ProviderConfigByName(cfg, primaryName); ok { loop.providerResponses[primaryName] = pc.Responses + loop.providerMaxTokens[primaryName] = pc.MaxTokens + loop.providerTemperatures[primaryName] = pc.Temperature } seenProviders := map[string]struct{}{primaryName: {}} providerConfigs := config.AllProviderConfigs(cfg) @@ -304,6 +314,8 @@ func NewAgentLoop(cfg *config.Config, msgBus *bus.MessageBus, provider providers modelName = strings.TrimSpace(pc.Models[0]) } loop.providerResponses[providerName] = pc.Responses + loop.providerMaxTokens[providerName] = pc.MaxTokens + loop.providerTemperatures[providerName] = pc.Temperature } seenProviders[providerName] = struct{}{} loop.providerNames = append(loop.providerNames, providerName) @@ -464,7 +476,7 @@ func (al *AgentLoop) buildSessionShards(ctx context.Context) []chan bus.InboundM return shards } -func (al *AgentLoop) tryFallbackProviders(ctx context.Context, msg bus.InboundMessage, messages []providers.Message, toolDefs []providers.ToolDefinition, options map[string]interface{}, primaryErr error) (*providers.LLMResponse, string, error) { +func (al *AgentLoop) tryFallbackProviders(ctx context.Context, msg bus.InboundMessage, messages []providers.Message, toolDefs []providers.ToolDefinition, primaryErr error) (*providers.LLMResponse, string, error) { if len(al.providerChain) <= 1 { return nil, "", primaryErr } @@ -495,8 +507,10 @@ func (al *AgentLoop) tryFallbackProviders(ctx context.Context, msg bus.InboundMe lastErr = err continue } - resp, err := p.Chat(ctx, messages, toolDefs, candidateModel, options) + fallbackOptions := al.buildResponsesOptionsForProvider(msg.SessionKey, candidate.name, int64(al.maxTokensForProvider(candidate.name)), al.temperatureForProvider(candidate.name)) + resp, err := p.Chat(ctx, messages, toolDefs, candidateModel, fallbackOptions) if err == nil { + al.setSessionProvider(msg.SessionKey, candidate.name) logger.WarnCF("agent", logger.C0150, map[string]interface{}{"provider": candidate.name, "model": candidateModel, "ref": candidate.ref}) return resp, candidate.name, nil } @@ -550,6 +564,76 @@ func (al *AgentLoop) ensureProviderCandidate(candidate providerCandidate) (provi return created, model, nil } +func (al *AgentLoop) providerCandidateByName(name string) (providerCandidate, bool) { + if al == nil { + return providerCandidate{}, false + } + target := strings.TrimSpace(name) + if target == "" { + return providerCandidate{}, false + } + for _, candidate := range al.providerChain { + if strings.EqualFold(strings.TrimSpace(candidate.name), target) { + return candidate, true + } + } + return providerCandidate{}, false +} + +func (al *AgentLoop) defaultProviderName() string { + if al == nil || len(al.providerNames) == 0 { + return "" + } + return strings.TrimSpace(al.providerNames[0]) +} + +func (al *AgentLoop) sessionProviderName(sessionKey string) string { + name := strings.TrimSpace(al.getSessionProvider(sessionKey)) + if name == "" { + name = al.defaultProviderName() + } + return name +} + +func (al *AgentLoop) isKnownProviderName(name string) bool { + target := strings.TrimSpace(name) + if target == "" { + return false + } + for _, item := range al.providerNames { + if strings.EqualFold(strings.TrimSpace(item), target) { + return true + } + } + return false +} + +func (al *AgentLoop) activeProviderForSession(sessionKey string) (providers.LLMProvider, string, string, error) { + if al == nil { + return nil, "", "", fmt.Errorf("agent loop is nil") + } + name := al.sessionProviderName(sessionKey) + if name == "" { + return al.provider, al.model, "", nil + } + if strings.EqualFold(name, al.defaultProviderName()) { + model := strings.TrimSpace(al.model) + if model == "" && al.provider != nil { + model = strings.TrimSpace(al.provider.GetDefaultModel()) + } + return al.provider, model, name, nil + } + candidate, ok := al.providerCandidateByName(name) + if !ok { + return al.provider, al.model, name, nil + } + p, model, err := al.ensureProviderCandidate(candidate) + if err != nil { + return nil, "", name, err + } + return p, model, name, nil +} + func automaticFallbackPriority(name string) int { switch normalizeFallbackProviderName(name) { case "claude": @@ -610,10 +694,22 @@ func (al *AgentLoop) getSessionProvider(sessionKey string) string { } func (al *AgentLoop) syncSessionDefaultProvider(sessionKey string) { - if al == nil || len(al.providerNames) == 0 { + if al == nil { return } - al.setSessionProvider(sessionKey, al.providerNames[0]) + current := strings.TrimSpace(al.getSessionProvider(sessionKey)) + if current == "" { + if name := al.defaultProviderName(); name != "" { + al.setSessionProvider(sessionKey, name) + } + return + } + if al.isKnownProviderName(current) { + return + } + if name := al.defaultProviderName(); name != "" { + al.setSessionProvider(sessionKey, name) + } } func (al *AgentLoop) markSessionStreamed(sessionKey string) { @@ -707,6 +803,8 @@ func sessionShardIndex(sessionKey string, shardCount int) int { return int(h.Sum32() % uint32(shardCount)) } +var thinkTagPattern = regexp.MustCompile(`(?s).*?`) + func (al *AgentLoop) getTrigger(msg bus.InboundMessage) string { if msg.Metadata != nil { if t := strings.TrimSpace(msg.Metadata["trigger"]); t != "" { @@ -748,6 +846,413 @@ func (al *AgentLoop) shouldSuppressOutbound(msg bus.InboundMessage, response str return len(r) <= maxChars } +type llmTurnLoopConfig struct { + ctx context.Context + triggerMsg bus.InboundMessage + sessionKey string + toolChannel string + toolChatID string + messages []providers.Message + media []string + mediaItems []bus.MediaItem + enableStreaming bool + errorLogCode logger.CodeID + logDirectResponse bool +} + +type llmTurnLoopResult struct { + messages []providers.Message + pendingPersist []providers.Message + finalContent string + iteration int + hasToolActivity bool +} + +func logLLMTurnRequest(iteration, maxIterations int, providerName, activeModel string, messages []providers.Message, providerToolDefs []providers.ToolDefinition, maxTokens int, temperature float64) { + systemPromptLen := 0 + if len(messages) > 0 { + systemPromptLen = len(messages[0].Content) + } + logger.DebugCF("agent", logger.C0152, map[string]interface{}{ + "iteration": iteration, + "max": maxIterations, + "provider": providerName, + "model": activeModel, + "messages_count": len(messages), + "tools_count": len(providerToolDefs), + "max_tokens": maxTokens, + "temperature": temperature, + "system_prompt_len": systemPromptLen, + }) + if iteration == 1 { + logger.DebugCF("agent", logger.C0153, map[string]interface{}{ + "iteration": iteration, + "messages_json": formatMessagesForLog(messages), + "tools_json": formatToolsForLog(providerToolDefs), + }) + } +} + +func logLLMDirectResponse(iteration int, finalContent string) { + logger.InfoCF("agent", logger.C0156, map[string]interface{}{ + "iteration": iteration, + "content_chars": len(finalContent), + }) +} + +func logLLMToolCalls(iteration int, toolCalls []providers.ToolCall) { + toolNames := make([]string, 0, len(toolCalls)) + for _, tc := range toolCalls { + toolNames = append(toolNames, tc.Name) + } + logger.InfoCF("agent", logger.C0157, map[string]interface{}{ + "tools": toolNames, + "count": len(toolNames), + "iteration": iteration, + }) +} + +func buildAssistantToolCallMessage(response *providers.LLMResponse) providers.Message { + assistantMsg := providers.Message{ + Role: "assistant", + Content: response.Content, + } + if response == nil { + return assistantMsg + } + for _, tc := range response.ToolCalls { + argumentsJSON, _ := json.Marshal(tc.Arguments) + assistantMsg.ToolCalls = append(assistantMsg.ToolCalls, providers.ToolCall{ + ID: tc.ID, + Type: "function", + Function: &providers.FunctionCall{ + Name: tc.Name, + Arguments: string(argumentsJSON), + }, + }) + } + return assistantMsg +} + +func (al *AgentLoop) executeResponseToolCalls(cfg llmTurnLoopConfig, iteration int, response *providers.LLMResponse) []providers.Message { + if response == nil || len(response.ToolCalls) == 0 { + return nil + } + results := make([]providers.Message, 0, len(response.ToolCalls)) + for _, tc := range response.ToolCalls { + argsJSON, _ := json.Marshal(tc.Arguments) + logger.InfoCF("agent", logger.C0172, map[string]interface{}{ + "tool": tc.Name, + "args": truncate(string(argsJSON), 200), + "iteration": iteration, + }) + execArgs := withToolContextArgs(tc.Name, tc.Arguments, cfg.toolChannel, cfg.toolChatID) + toolResult, toolErr := al.executeToolCall(cfg.ctx, tc.Name, execArgs, cfg.toolChannel, cfg.toolChatID) + if toolErr != nil { + toolResult = fmt.Sprintf("Error: %v", toolErr) + } + results = append(results, providers.Message{ + Role: "tool", + Content: toolResult, + ToolCallID: tc.ID, + }) + } + return results +} + +func (al *AgentLoop) requestLLMResponse(cfg llmTurnLoopConfig, activeProvider providers.LLMProvider, activeModel string, messages []providers.Message, providerToolDefs []providers.ToolDefinition, options map[string]interface{}) (*providers.LLMResponse, error) { + if cfg.enableStreaming { + if sp, ok := activeProvider.(providers.StreamingLLMProvider); ok { + streamText := "" + lastPush := time.Now().Add(-time.Second) + return sp.ChatStream(cfg.ctx, messages, providerToolDefs, activeModel, options, func(delta string) { + if strings.TrimSpace(delta) == "" { + return + } + streamText += delta + if time.Since(lastPush) < 450*time.Millisecond { + return + } + if !shouldFlushTelegramStreamSnapshot(streamText) { + return + } + lastPush = time.Now() + replyID := "" + if cfg.triggerMsg.Metadata != nil { + replyID = cfg.triggerMsg.Metadata["message_id"] + } + al.bus.PublishOutbound(bus.OutboundMessage{ + Channel: cfg.toolChannel, + ChatID: cfg.toolChatID, + Content: streamText, + Action: "stream", + ReplyToID: replyID, + }) + al.markSessionStreamed(cfg.sessionKey) + }) + } + } + return activeProvider.Chat(cfg.ctx, messages, providerToolDefs, activeModel, options) +} + +func (al *AgentLoop) runLLMTurnLoop(cfg llmTurnLoopConfig) (llmTurnLoopResult, error) { + result := llmTurnLoopResult{ + messages: append([]providers.Message(nil), cfg.messages...), + pendingPersist: make([]providers.Message, 0, 16), + } + maxAllowed := al.maxIterations + if maxAllowed < 1 { + maxAllowed = 1 + } + toolDefs := al.filteredToolDefinitionsForContext(cfg.ctx) + providerToolDefs := al.buildProviderToolDefs(toolDefs) + result.messages = injectResponsesMediaParts(result.messages, cfg.media, cfg.mediaItems) + + for result.iteration < maxAllowed { + result.iteration++ + activeProvider, activeModel, providerName, err := al.activeProviderForSession(cfg.sessionKey) + if err != nil { + logger.ErrorCF("agent", cfg.errorLogCode, map[string]interface{}{ + "iteration": result.iteration, + "error": err.Error(), + }) + return result, fmt.Errorf("resolve active provider: %w", err) + } + if activeProvider == nil { + return result, fmt.Errorf("active provider unavailable for session %s", strings.TrimSpace(cfg.sessionKey)) + } + + maxTokens := al.maxTokensForProvider(providerName) + temperature := al.temperatureForProvider(providerName) + logLLMTurnRequest(result.iteration, al.maxIterations, providerName, activeModel, result.messages, providerToolDefs, maxTokens, temperature) + + options := al.buildResponsesOptions(cfg.sessionKey, int64(maxTokens), temperature) + response, err := al.requestLLMResponse(cfg, activeProvider, activeModel, result.messages, providerToolDefs, options) + + if err != nil { + if fb, _, ferr := al.tryFallbackProviders(cfg.ctx, cfg.triggerMsg, result.messages, providerToolDefs, err); ferr == nil && fb != nil { + response = fb + err = nil + } else { + err = ferr + } + } + if err != nil { + logger.ErrorCF("agent", cfg.errorLogCode, map[string]interface{}{ + "iteration": result.iteration, + "error": err.Error(), + }) + return result, fmt.Errorf("LLM call failed: %w", err) + } + + if len(response.ToolCalls) == 0 { + result.finalContent = response.Content + if cfg.logDirectResponse { + logLLMDirectResponse(result.iteration, result.finalContent) + } + return result, nil + } + + logLLMToolCalls(result.iteration, response.ToolCalls) + + assistantMsg := buildAssistantToolCallMessage(response) + result.messages = append(result.messages, assistantMsg) + result.pendingPersist = append(result.pendingPersist, assistantMsg) + result.hasToolActivity = true + if maxAllowed < result.iteration+al.maxIterations { + maxAllowed = result.iteration + al.maxIterations + } + + for _, toolResultMsg := range al.executeResponseToolCalls(cfg, result.iteration, response) { + result.messages = append(result.messages, toolResultMsg) + result.pendingPersist = append(result.pendingPersist, toolResultMsg) + } + } + + return result, nil +} + +func (al *AgentLoop) logInboundMessageStart(msg bus.InboundMessage) { + logger.InfoCF("agent", logger.C0171, map[string]interface{}{ + "channel": msg.Channel, + "chat_id": msg.ChatID, + "sender_id": msg.SenderID, + "session_key": msg.SessionKey, + "preview": truncate(msg.Content, 80), + }) +} + +func (al *AgentLoop) prepareUserMessageContext(msg bus.InboundMessage, memoryNamespace string) ([]providers.Message, string) { + history := al.sessions.GetHistory(msg.SessionKey) + summary := al.sessions.GetSummary(msg.SessionKey) + al.sessions.AddMessage(msg.SessionKey, "user", msg.Content) + if explicitPref := ExtractLanguagePreference(msg.Content); explicitPref != "" { + al.sessions.SetPreferredLanguage(msg.SessionKey, explicitPref) + } + preferredLang, lastLang := al.sessions.GetLanguagePreferences(msg.SessionKey) + responseLang := DetectResponseLanguage(msg.Content, preferredLang, lastLang) + messages := al.contextBuilder.BuildMessagesWithMemoryNamespace( + history, + summary, + msg.Content, + nil, + msg.Channel, + msg.ChatID, + responseLang, + memoryNamespace, + ) + return messages, responseLang +} + +func (al *AgentLoop) finalizeUserMessage(sessionKey, responseLang string, pendingPersist []providers.Message, finalContent string) { + for _, persisted := range pendingPersist { + al.sessions.AddMessageFull(sessionKey, persisted) + } + al.sessions.AddMessageFull(sessionKey, providers.Message{ + Role: "assistant", + Content: finalContent, + }) + al.sessions.SetLastLanguage(sessionKey, responseLang) + al.compactSessionIfNeeded(sessionKey) + _ = al.sessions.Save(al.sessions.GetOrCreate(sessionKey)) +} + +func (al *AgentLoop) prepareSystemMessageContext(sessionKey string, msg bus.InboundMessage, originChannel, originChatID string) ([]providers.Message, string) { + history := al.sessions.GetHistory(sessionKey) + summary := al.sessions.GetSummary(sessionKey) + preferredLang, lastLang := al.sessions.GetLanguagePreferences(sessionKey) + responseLang := DetectResponseLanguage(msg.Content, preferredLang, lastLang) + messages := al.contextBuilder.BuildMessages( + history, + summary, + msg.Content, + nil, + originChannel, + originChatID, + responseLang, + ) + return messages, responseLang +} + +func (al *AgentLoop) finalizeSystemMessage(sessionKey, responseLang string, msg bus.InboundMessage, pendingPersist []providers.Message, finalContent string) { + al.sessions.AddMessage(sessionKey, "user", fmt.Sprintf("[System: %s] %s", msg.SenderID, msg.Content)) + for _, persisted := range pendingPersist { + al.sessions.AddMessageFull(sessionKey, persisted) + } + al.sessions.AddMessageFull(sessionKey, providers.Message{ + Role: "assistant", + Content: finalContent, + }) + al.sessions.SetLastLanguage(sessionKey, responseLang) + al.compactSessionIfNeeded(sessionKey) + _ = al.sessions.Save(al.sessions.GetOrCreate(sessionKey)) +} + +func (al *AgentLoop) startSpecTaskForMessage(msg bus.InboundMessage) specCodingTaskRef { + specTaskRef := specCodingTaskRef{} + if err := al.maybeEnsureSpecCodingDocs(msg.Content); err != nil { + logger.WarnCF("agent", logger.C0172, map[string]interface{}{ + "session_key": msg.SessionKey, + "error": err.Error(), + }) + } + taskRef, err := al.maybeStartSpecCodingTask(msg.Content) + if err != nil { + logger.WarnCF("agent", logger.C0172, map[string]interface{}{ + "session_key": msg.SessionKey, + "error": err.Error(), + }) + return specTaskRef + } + return normalizeSpecCodingTaskRef(taskRef) +} + +func (al *AgentLoop) reopenSpecTaskOnError(specTaskRef specCodingTaskRef, msg bus.InboundMessage, err error) { + if specTaskRef.Summary == "" || err == nil { + return + } + if rerr := al.maybeReopenSpecCodingTask(specTaskRef, msg.Content, err.Error()); rerr != nil { + logger.WarnCF("agent", logger.C0172, map[string]interface{}{ + "session_key": msg.SessionKey, + "error": rerr.Error(), + }) + } +} + +func (al *AgentLoop) completeSpecTaskOnSuccess(specTaskRef specCodingTaskRef, msg bus.InboundMessage, output string) { + if specTaskRef.Summary == "" { + return + } + if err := al.maybeCompleteSpecCodingTask(specTaskRef, output); err != nil { + logger.WarnCF("agent", logger.C0172, map[string]interface{}{ + "session_key": msg.SessionKey, + "error": err.Error(), + }) + } +} + +func (al *AgentLoop) recoverFinalContentAfterToolCalls(ctx context.Context, sessionKey string, messages []providers.Message, hasToolActivity bool) string { + if !hasToolActivity { + return "" + } + activeProvider, activeModel, providerName, err := al.activeProviderForSession(sessionKey) + if err != nil { + logger.WarnCF("agent", logger.C0172, map[string]interface{}{ + "session_key": sessionKey, + "error": err.Error(), + }) + return "" + } + if activeProvider == nil { + return "" + } + options := al.buildResponsesOptionsForProvider(sessionKey, providerName, int64(al.maxTokensForProvider(providerName)), 0.2) + forced, ferr := activeProvider.Chat(ctx, messages, nil, activeModel, options) + if ferr != nil || forced == nil { + if ferr != nil { + logger.WarnCF("agent", logger.C0172, map[string]interface{}{ + "session_key": sessionKey, + "error": ferr.Error(), + }) + } + return "" + } + return forced.Content +} + +func sanitizeUserVisibleContent(finalContent string, iteration int) string { + userContent := thinkTagPattern.ReplaceAllString(finalContent, "") + if userContent == "" && finalContent != "" && iteration == 1 { + return "Thinking process completed." + } + return userContent +} + +func (al *AgentLoop) finalizeUserTurnResponse(ctx context.Context, msg bus.InboundMessage, responseLang string, loopResult llmTurnLoopResult) (string, string) { + finalContent := loopResult.finalContent + if finalContent == "" { + if recovered := al.recoverFinalContentAfterToolCalls(ctx, msg.SessionKey, loopResult.messages, loopResult.hasToolActivity); recovered != "" { + finalContent = recovered + } + } + userContent := sanitizeUserVisibleContent(finalContent, loopResult.iteration) + al.finalizeUserMessage(msg.SessionKey, responseLang, loopResult.pendingPersist, userContent) + return finalContent, userContent +} + +func (al *AgentLoop) maybeHandleAutoRoute(ctx context.Context, msg bus.InboundMessage, specTaskRef specCodingTaskRef) (string, error, bool) { + routed, ok, routeErr := al.maybeAutoRoute(ctx, msg) + if !ok { + return "", nil, false + } + if routeErr != nil { + al.reopenSpecTaskOnError(specTaskRef, msg, routeErr) + return routed, routeErr, true + } + al.completeSpecTaskOnSuccess(specTaskRef, msg, routed) + return routed, nil, true +} + func loadHeartbeatAckToken(workspace string) string { workspace = strings.TrimSpace(workspace) if workspace == "" { @@ -879,282 +1384,37 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage) } defer release() al.syncSessionDefaultProvider(msg.SessionKey) - // Add message preview to log - preview := truncate(msg.Content, 80) - logger.InfoCF("agent", logger.C0171, - map[string]interface{}{ - "channel": msg.Channel, - "chat_id": msg.ChatID, - "sender_id": msg.SenderID, - "session_key": msg.SessionKey, - "preview": preview, - }) + al.logInboundMessageStart(msg) // Route system messages to processSystemMessage if msg.Channel == "system" { return al.processSystemMessage(ctx, msg) } - specTaskRef := specCodingTaskRef{} - if err := al.maybeEnsureSpecCodingDocs(msg.Content); err != nil { - logger.WarnCF("agent", logger.C0172, map[string]interface{}{ - "session_key": msg.SessionKey, - "error": err.Error(), - }) - } - if taskRef, err := al.maybeStartSpecCodingTask(msg.Content); err != nil { - logger.WarnCF("agent", logger.C0172, map[string]interface{}{ - "session_key": msg.SessionKey, - "error": err.Error(), - }) - } else { - specTaskRef = normalizeSpecCodingTaskRef(taskRef) - } - if routed, ok, routeErr := al.maybeAutoRoute(ctx, msg); ok { - if routeErr != nil && specTaskRef.Summary != "" { - if err := al.maybeReopenSpecCodingTask(specTaskRef, msg.Content, routeErr.Error()); err != nil { - logger.WarnCF("agent", logger.C0172, map[string]interface{}{ - "session_key": msg.SessionKey, - "error": err.Error(), - }) - } - } - if routeErr == nil && specTaskRef.Summary != "" { - if err := al.maybeCompleteSpecCodingTask(specTaskRef, routed); err != nil { - logger.WarnCF("agent", logger.C0172, map[string]interface{}{ - "session_key": msg.SessionKey, - "error": err.Error(), - }) - } - } + specTaskRef := al.startSpecTaskForMessage(msg) + if routed, routeErr, handled := al.maybeHandleAutoRoute(ctx, msg, specTaskRef); handled { return routed, routeErr } - history := al.sessions.GetHistory(msg.SessionKey) - summary := al.sessions.GetSummary(msg.SessionKey) - if explicitPref := ExtractLanguagePreference(msg.Content); explicitPref != "" { - al.sessions.SetPreferredLanguage(msg.SessionKey, explicitPref) - } - preferredLang, lastLang := al.sessions.GetLanguagePreferences(msg.SessionKey) - responseLang := DetectResponseLanguage(msg.Content, preferredLang, lastLang) + messages, responseLang := al.prepareUserMessageContext(msg, memoryNamespace) - messages := al.contextBuilder.BuildMessagesWithMemoryNamespace( - history, - summary, - msg.Content, - nil, - msg.Channel, - msg.ChatID, - responseLang, - memoryNamespace, - ) - - iteration := 0 - var finalContent string - hasToolActivity := false - lastToolOutputs := make([]string, 0, 4) - maxAllowed := al.maxIterations - if maxAllowed < 1 { - maxAllowed = 1 - } - for iteration < maxAllowed { - iteration++ - - logger.DebugCF("agent", logger.C0151, - map[string]interface{}{ - "iteration": iteration, - "max": al.maxIterations, - }) - - toolDefs := al.filteredToolDefinitionsForContext(ctx) - providerToolDefs := al.buildProviderToolDefs(toolDefs) - - // Log LLM request details - logger.DebugCF("agent", logger.C0152, - map[string]interface{}{ - "iteration": iteration, - "model": al.model, - "messages_count": len(messages), - "tools_count": len(providerToolDefs), - "max_tokens": 8192, - "temperature": 0.7, - "system_prompt_len": len(messages[0].Content), - }) - - // Log full messages (detailed) - logger.DebugCF("agent", logger.C0153, - map[string]interface{}{ - "iteration": iteration, - "messages_json": formatMessagesForLog(messages), - "tools_json": formatToolsForLog(providerToolDefs), - }) - - messages = injectResponsesMediaParts(messages, msg.Media, msg.MediaItems) - options := al.buildResponsesOptions(msg.SessionKey, 8192, 0.7) - var response *providers.LLMResponse - var err error - if msg.Channel == "telegram" && al.telegramStreaming { - if sp, ok := al.provider.(providers.StreamingLLMProvider); ok { - streamText := "" - lastPush := time.Now().Add(-time.Second) - response, err = sp.ChatStream(ctx, messages, providerToolDefs, al.model, options, func(delta string) { - if strings.TrimSpace(delta) == "" { - return - } - streamText += delta - if time.Since(lastPush) < 450*time.Millisecond { - return - } - if !shouldFlushTelegramStreamSnapshot(streamText) { - return - } - lastPush = time.Now() - replyID := "" - if msg.Metadata != nil { - replyID = msg.Metadata["message_id"] - } - // Stream with formatted rendering once snapshot is syntactically safe. - al.bus.PublishOutbound(bus.OutboundMessage{Channel: msg.Channel, ChatID: msg.ChatID, Content: streamText, Action: "stream", ReplyToID: replyID}) - al.markSessionStreamed(msg.SessionKey) - }) - } else { - response, err = al.provider.Chat(ctx, messages, providerToolDefs, al.model, options) - } - } else { - response, err = al.provider.Chat(ctx, messages, providerToolDefs, al.model, options) - } - - if err != nil { - if fb, _, ferr := al.tryFallbackProviders(ctx, msg, messages, providerToolDefs, options, err); ferr == nil && fb != nil { - response = fb - err = nil - } else { - err = ferr - } - } - if err != nil { - logger.ErrorCF("agent", logger.C0155, - map[string]interface{}{ - "iteration": iteration, - "error": err.Error(), - }) - if specTaskRef.Summary != "" { - if rerr := al.maybeReopenSpecCodingTask(specTaskRef, msg.Content, err.Error()); rerr != nil { - logger.WarnCF("agent", logger.C0172, map[string]interface{}{ - "session_key": msg.SessionKey, - "error": rerr.Error(), - }) - } - } - return "", fmt.Errorf("LLM call failed: %w", err) - } - - if len(response.ToolCalls) == 0 { - finalContent = response.Content - logger.InfoCF("agent", logger.C0156, - map[string]interface{}{ - "iteration": iteration, - "content_chars": len(finalContent), - }) - break - } - - toolNames := make([]string, 0, len(response.ToolCalls)) - for _, tc := range response.ToolCalls { - toolNames = append(toolNames, tc.Name) - } - logger.InfoCF("agent", logger.C0157, - map[string]interface{}{ - "tools": toolNames, - "count": len(toolNames), - "iteration": iteration, - }) - - assistantMsg := providers.Message{ - Role: "assistant", - Content: response.Content, - } - - for _, tc := range response.ToolCalls { - argumentsJSON, _ := json.Marshal(tc.Arguments) - assistantMsg.ToolCalls = append(assistantMsg.ToolCalls, providers.ToolCall{ - ID: tc.ID, - Type: "function", - Function: &providers.FunctionCall{ - Name: tc.Name, - Arguments: string(argumentsJSON), - }, - }) - } - messages = append(messages, assistantMsg) - // Persist assistant message with tool calls. - al.sessions.AddMessageFull(msg.SessionKey, assistantMsg) - - hasToolActivity = true - // Extend rolling window as long as tools keep chaining. - if maxAllowed < iteration+al.maxIterations { - maxAllowed = iteration + al.maxIterations - } - for _, tc := range response.ToolCalls { - // Log tool call with arguments preview - argsJSON, _ := json.Marshal(tc.Arguments) - argsPreview := truncate(string(argsJSON), 200) - logger.InfoCF("agent", logger.C0172, - map[string]interface{}{ - "tool": tc.Name, - "args": argsPreview, - "iteration": iteration, - }) - - execArgs := withToolContextArgs(tc.Name, tc.Arguments, msg.Channel, msg.ChatID) - result, err := al.executeToolCall(ctx, tc.Name, execArgs, msg.Channel, msg.ChatID) - if err != nil { - result = fmt.Sprintf("Error: %v", err) - } - if len(lastToolOutputs) < 4 { - lastToolOutputs = append(lastToolOutputs, fmt.Sprintf("%s: %s", tc.Name, truncate(strings.ReplaceAll(result, "\n", " "), 180))) - } - toolResultMsg := providers.Message{ - Role: "tool", - Content: result, - ToolCallID: tc.ID, - } - messages = append(messages, toolResultMsg) - // Persist tool result message. - al.sessions.AddMessageFull(msg.SessionKey, toolResultMsg) - } - } - - if finalContent == "" && hasToolActivity { - forced, ferr := al.provider.Chat(ctx, messages, nil, al.model, map[string]interface{}{"max_tokens": 8192, "temperature": 0.2}) - if ferr == nil && forced != nil && forced.Content != "" { - finalContent = forced.Content - } - } - - // Filter out ... content from user-facing response - // Keep full content in debug logs if needed, but remove from final output - re := regexp.MustCompile(`(?s).*?`) - userContent := re.ReplaceAllString(finalContent, "") - if userContent == "" && finalContent != "" { - // If only thoughts were present, maybe provide a generic "Done" or keep something? - // For now, let's assume thoughts are auxiliary and empty response is okay if tools did work. - // If no tools ran and only thoughts, user might be confused. - if iteration == 1 { - userContent = "Thinking process completed." - } - } - - al.sessions.AddMessage(msg.SessionKey, "user", msg.Content) - - // Persist full assistant response (including reasoning/tool flow outcomes when present). - al.sessions.AddMessageFull(msg.SessionKey, providers.Message{ - Role: "assistant", - Content: userContent, + loopResult, err := al.runLLMTurnLoop(llmTurnLoopConfig{ + ctx: ctx, + triggerMsg: msg, + sessionKey: msg.SessionKey, + toolChannel: msg.Channel, + toolChatID: msg.ChatID, + messages: messages, + media: msg.Media, + mediaItems: msg.MediaItems, + enableStreaming: msg.Channel == "telegram" && al.telegramStreaming, + errorLogCode: logger.C0155, + logDirectResponse: true, }) - al.sessions.SetLastLanguage(msg.SessionKey, responseLang) - al.compactSessionIfNeeded(msg.SessionKey) - - al.sessions.Save(al.sessions.GetOrCreate(msg.SessionKey)) + if err != nil { + al.reopenSpecTaskOnError(specTaskRef, msg, err) + return "", err + } + finalContent, userContent := al.finalizeUserTurnResponse(ctx, msg, responseLang, loopResult) // Log response preview (original content) responsePreview := truncate(finalContent, 120) @@ -1163,19 +1423,12 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage) "channel": msg.Channel, "sender_id": msg.SenderID, "preview": responsePreview, - "iterations": iteration, + "iterations": loopResult.iteration, "final_length": len(finalContent), "user_length": len(userContent), }) - if specTaskRef.Summary != "" { - if err := al.maybeCompleteSpecCodingTask(specTaskRef, userContent); err != nil { - logger.WarnCF("agent", logger.C0172, map[string]interface{}{ - "session_key": msg.SessionKey, - "error": err.Error(), - }) - } - } + al.completeSpecTaskOnSuccess(specTaskRef, msg, userContent) al.appendDailySummaryLog(msg, userContent) return userContent, nil } @@ -1327,130 +1580,30 @@ func (al *AgentLoop) processSystemMessage(ctx context.Context, msg bus.InboundMe // Use the origin session for context sessionKey := fmt.Sprintf("%s:%s", originChannel, originChatID) - // Build messages with the announce content - history := al.sessions.GetHistory(sessionKey) - summary := al.sessions.GetSummary(sessionKey) - preferredLang, lastLang := al.sessions.GetLanguagePreferences(sessionKey) - responseLang := DetectResponseLanguage(msg.Content, preferredLang, lastLang) - messages := al.contextBuilder.BuildMessages( - history, - summary, - msg.Content, - nil, - originChannel, - originChatID, - responseLang, - ) + messages, responseLang := al.prepareSystemMessageContext(sessionKey, msg, originChannel, originChatID) - iteration := 0 - var finalContent string - - for iteration < al.maxIterations { - iteration++ - - toolDefs := al.filteredToolDefinitionsForContext(ctx) - providerToolDefs := al.buildProviderToolDefs(toolDefs) - - // Log LLM request details - logger.DebugCF("agent", logger.C0152, - map[string]interface{}{ - "iteration": iteration, - "model": al.model, - "messages_count": len(messages), - "tools_count": len(providerToolDefs), - "max_tokens": 8192, - "temperature": 0.7, - "system_prompt_len": len(messages[0].Content), - }) - - // Log full messages (detailed) - logger.DebugCF("agent", logger.C0153, - map[string]interface{}{ - "iteration": iteration, - "messages_json": formatMessagesForLog(messages), - "tools_json": formatToolsForLog(providerToolDefs), - }) - - options := al.buildResponsesOptions(sessionKey, 8192, 0.7) - response, err := al.provider.Chat(ctx, messages, providerToolDefs, al.model, options) - - if err != nil { - if fb, _, ferr := al.tryFallbackProviders(ctx, msg, messages, providerToolDefs, options, err); ferr == nil && fb != nil { - response = fb - err = nil - } else { - err = ferr - } - } - if err != nil { - logger.ErrorCF("agent", logger.C0162, - map[string]interface{}{ - "iteration": iteration, - "error": err.Error(), - }) - return "", fmt.Errorf("LLM call failed: %w", err) - } - - if len(response.ToolCalls) == 0 { - finalContent = response.Content - break - } - - assistantMsg := providers.Message{ - Role: "assistant", - Content: response.Content, - } - - for _, tc := range response.ToolCalls { - argumentsJSON, _ := json.Marshal(tc.Arguments) - assistantMsg.ToolCalls = append(assistantMsg.ToolCalls, providers.ToolCall{ - ID: tc.ID, - Type: "function", - Function: &providers.FunctionCall{ - Name: tc.Name, - Arguments: string(argumentsJSON), - }, - }) - } - messages = append(messages, assistantMsg) - // Persist assistant message with tool calls. - al.sessions.AddMessageFull(sessionKey, assistantMsg) - - for _, tc := range response.ToolCalls { - execArgs := withToolContextArgs(tc.Name, tc.Arguments, originChannel, originChatID) - result, err := al.executeToolCall(ctx, tc.Name, execArgs, originChannel, originChatID) - if err != nil { - result = fmt.Sprintf("Error: %v", err) - } - - toolResultMsg := providers.Message{ - Role: "tool", - Content: result, - ToolCallID: tc.ID, - } - messages = append(messages, toolResultMsg) - // Persist tool result message. - al.sessions.AddMessageFull(sessionKey, toolResultMsg) - } + loopResult, err := al.runLLMTurnLoop(llmTurnLoopConfig{ + ctx: ctx, + triggerMsg: msg, + sessionKey: sessionKey, + toolChannel: originChannel, + toolChatID: originChatID, + messages: messages, + errorLogCode: logger.C0162, + logDirectResponse: false, + }) + if err != nil { + return "", err } + iteration := loopResult.iteration + finalContent := loopResult.finalContent + pendingPersist := loopResult.pendingPersist if finalContent == "" { finalContent = "Background task completed." } - // Save to session with system message marker - al.sessions.AddMessage(sessionKey, "user", fmt.Sprintf("[System: %s] %s", msg.SenderID, msg.Content)) - - // If finalContent has no tool calls (last LLM turn is direct text), - // earlier steps were already persisted in-loop; this stores the final reply. - al.sessions.AddMessageFull(sessionKey, providers.Message{ - Role: "assistant", - Content: finalContent, - }) - al.sessions.SetLastLanguage(sessionKey, responseLang) - al.compactSessionIfNeeded(sessionKey) - - al.sessions.Save(al.sessions.GetOrCreate(sessionKey)) + al.finalizeSystemMessage(sessionKey, responseLang, msg, pendingPersist, finalContent) logger.InfoCF("agent", logger.C0163, map[string]interface{}{ @@ -1564,16 +1717,30 @@ func filterToolDefinitionsByContext(ctx context.Context, toolDefs []map[string]i } func (al *AgentLoop) buildResponsesOptions(sessionKey string, maxTokens int64, temperature float64) map[string]interface{} { + providerName := strings.TrimSpace(al.getSessionProvider(sessionKey)) + if providerName == "" && len(al.providerNames) > 0 { + providerName = al.providerNames[0] + } + return al.buildResponsesOptionsForProvider(sessionKey, providerName, maxTokens, temperature) +} + +func (al *AgentLoop) buildResponsesOptionsForProvider(sessionKey, providerName string, maxTokens int64, temperature float64) map[string]interface{} { + if maxTokens <= 0 { + maxTokens = int64(al.maxTokensForProvider(providerName)) + } + if math.IsNaN(temperature) { + temperature = al.temperatureForProvider(providerName) + } options := map[string]interface{}{ "max_tokens": maxTokens, "temperature": temperature, } - if strings.EqualFold(strings.TrimSpace(al.getSessionProvider(sessionKey)), "codex") { + if strings.EqualFold(strings.TrimSpace(providerName), "codex") { if key := strings.TrimSpace(sessionKey); key != "" { options["codex_execution_session"] = key } } - responsesCfg := al.responsesConfigForSession(sessionKey) + responsesCfg := al.responsesConfigForProvider(providerName) responseTools := make([]map[string]interface{}, 0, 2) if responsesCfg.WebSearchEnabled { webTool := map[string]interface{}{"type": "web_search"} @@ -1604,6 +1771,60 @@ func (al *AgentLoop) buildResponsesOptions(sessionKey string, maxTokens int64, t return options } +func (al *AgentLoop) maxTokensForProvider(name string) int { + if al == nil { + return 8192 + } + providerName := strings.TrimSpace(name) + if providerName != "" { + if limit, ok := al.providerMaxTokens[providerName]; ok && limit > 0 { + return limit + } + } + if al.maxTokens > 0 { + return al.maxTokens + } + return 8192 +} + +func (al *AgentLoop) maxTokensForSession(sessionKey string) int { + if al == nil { + return 8192 + } + name := strings.TrimSpace(al.getSessionProvider(sessionKey)) + if name == "" && len(al.providerNames) > 0 { + name = al.providerNames[0] + } + return al.maxTokensForProvider(name) +} + +func (al *AgentLoop) temperatureForProvider(name string) float64 { + if al == nil { + return 0.7 + } + providerName := strings.TrimSpace(name) + if providerName != "" { + if value, ok := al.providerTemperatures[providerName]; ok && value != 0 { + return value + } + } + if al.temperature != 0 { + return al.temperature + } + return 0.7 +} + +func (al *AgentLoop) temperatureForSession(sessionKey string) float64 { + if al == nil { + return 0.7 + } + name := strings.TrimSpace(al.getSessionProvider(sessionKey)) + if name == "" && len(al.providerNames) > 0 { + name = al.providerNames[0] + } + return al.temperatureForProvider(name) +} + func (al *AgentLoop) responsesConfigForSession(sessionKey string) config.ProviderResponsesConfig { if al == nil { return config.ProviderResponsesConfig{} @@ -1612,6 +1833,13 @@ func (al *AgentLoop) responsesConfigForSession(sessionKey string) config.Provide if name == "" && len(al.providerNames) > 0 { name = al.providerNames[0] } + return al.responsesConfigForProvider(name) +} + +func (al *AgentLoop) responsesConfigForProvider(name string) config.ProviderResponsesConfig { + if al == nil { + return config.ProviderResponsesConfig{} + } if name == "" { return config.ProviderResponsesConfig{} } diff --git a/pkg/agent/loop_codex_options_test.go b/pkg/agent/loop_codex_options_test.go index 5ca8b3d..517f302 100644 --- a/pkg/agent/loop_codex_options_test.go +++ b/pkg/agent/loop_codex_options_test.go @@ -1,6 +1,63 @@ package agent -import "testing" +import ( + "context" + "errors" + "testing" + + "github.com/YspCoder/clawgo/pkg/bus" + "github.com/YspCoder/clawgo/pkg/config" + "github.com/YspCoder/clawgo/pkg/providers" +) + +type fallbackTestProvider struct { + response *providers.LLMResponse + err error + options map[string]interface{} +} + +func (p *fallbackTestProvider) Chat(ctx context.Context, messages []providers.Message, tools []providers.ToolDefinition, model string, options map[string]interface{}) (*providers.LLMResponse, error) { + p.options = map[string]interface{}{} + for k, v := range options { + p.options[k] = v + } + if p.err != nil { + return nil, p.err + } + return p.response, nil +} + +func (p *fallbackTestProvider) GetDefaultModel() string { return "fallback-model" } + +type sequenceProvider struct { + responses []*providers.LLMResponse + errs []error +} + +func (p *sequenceProvider) Chat(ctx context.Context, messages []providers.Message, tools []providers.ToolDefinition, model string, options map[string]interface{}) (*providers.LLMResponse, error) { + if len(p.responses) == 0 && len(p.errs) == 0 { + return &providers.LLMResponse{Content: "ok", FinishReason: "stop"}, nil + } + resp := (*providers.LLMResponse)(nil) + if len(p.responses) > 0 { + resp = p.responses[0] + p.responses = p.responses[1:] + } + var err error + if len(p.errs) > 0 { + err = p.errs[0] + p.errs = p.errs[1:] + } + if err != nil { + return nil, err + } + if resp == nil { + return &providers.LLMResponse{Content: "ok", FinishReason: "stop"}, nil + } + return resp, nil +} + +func (p *sequenceProvider) GetDefaultModel() string { return "sequence-model" } func TestBuildResponsesOptionsAddsCodexExecutionSession(t *testing.T) { loop := &AgentLoop{ @@ -42,3 +99,159 @@ func TestSyncSessionDefaultProviderOverridesStaleSessionProvider(t *testing.T) { t.Fatalf("expected stale session provider to be replaced with current default, got %q", got) } } + +func TestSyncSessionDefaultProviderKeepsKnownSessionProvider(t *testing.T) { + loop := &AgentLoop{ + providerNames: []string{"openai", "claude"}, + sessionProvider: map[string]string{ + "chat-1": "claude", + }, + } + + loop.syncSessionDefaultProvider("chat-1") + + if got := loop.getSessionProvider("chat-1"); got != "claude" { + t.Fatalf("expected valid session provider to be preserved, got %q", got) + } +} + +func TestMaxTokensForSessionUsesProviderOverride(t *testing.T) { + loop := &AgentLoop{ + maxTokens: 4096, + providerNames: []string{"openai"}, + sessionProvider: map[string]string{ + "chat-1": "claude", + }, + providerMaxTokens: map[string]int{ + "claude": 16384, + }, + } + + if got := loop.maxTokensForSession("chat-1"); got != 16384 { + t.Fatalf("expected provider max_tokens override, got %d", got) + } +} + +func TestMaxTokensForSessionFallsBackToAgentDefault(t *testing.T) { + loop := &AgentLoop{ + maxTokens: 4096, + providerNames: []string{"openai"}, + sessionProvider: map[string]string{}, + providerMaxTokens: map[string]int{}, + } + + if got := loop.maxTokensForSession("chat-1"); got != 4096 { + t.Fatalf("expected fallback to agent default max_tokens, got %d", got) + } +} + +func TestTemperatureForSessionUsesProviderOverride(t *testing.T) { + loop := &AgentLoop{ + temperature: 0.7, + providerNames: []string{"openai"}, + sessionProvider: map[string]string{ + "chat-1": "claude", + }, + providerTemperatures: map[string]float64{ + "claude": 0.15, + }, + } + + if got := loop.temperatureForSession("chat-1"); got != 0.15 { + t.Fatalf("expected provider temperature override, got %v", got) + } +} + +func TestTemperatureForSessionFallsBackToAgentDefault(t *testing.T) { + loop := &AgentLoop{ + temperature: 0.7, + providerNames: []string{"openai"}, + sessionProvider: map[string]string{}, + providerTemperatures: map[string]float64{}, + } + + if got := loop.temperatureForSession("chat-1"); got != 0.7 { + t.Fatalf("expected fallback to agent default temperature, got %v", got) + } +} + +func TestTryFallbackProvidersUsesFallbackProviderOptionsAndPersistsSelection(t *testing.T) { + fallback := &fallbackTestProvider{ + response: &providers.LLMResponse{Content: "fallback", FinishReason: "stop"}, + } + loop := &AgentLoop{ + maxTokens: 4096, + temperature: 0.7, + providerNames: []string{"openai", "claude"}, + sessionProvider: map[string]string{"chat-1": "openai"}, + providerPool: map[string]providers.LLMProvider{"claude": fallback}, + providerChain: []providerCandidate{{name: "openai", model: "gpt-a"}, {name: "claude", model: "claude-b"}}, + providerMaxTokens: map[string]int{"claude": 16384}, + providerTemperatures: map[string]float64{"claude": 0.15}, + providerResponses: map[string]config.ProviderResponsesConfig{ + "claude": {WebSearchEnabled: true}, + }, + } + + resp, providerName, err := loop.tryFallbackProviders(context.Background(), bus.InboundMessage{SessionKey: "chat-1"}, nil, nil, errors.New("primary failed")) + if err != nil { + t.Fatalf("expected fallback success, got %v", err) + } + if resp == nil || resp.Content != "fallback" { + t.Fatalf("unexpected fallback response: %#v", resp) + } + if providerName != "claude" { + t.Fatalf("expected provider claude, got %q", providerName) + } + if got := loop.getSessionProvider("chat-1"); got != "claude" { + t.Fatalf("expected session provider to switch to fallback provider, got %q", got) + } + if got := fallback.options["max_tokens"]; got != int64(16384) { + t.Fatalf("expected fallback max_tokens 16384, got %#v", got) + } + if got := fallback.options["temperature"]; got != 0.15 { + t.Fatalf("expected fallback temperature 0.15, got %#v", got) + } + if _, ok := fallback.options["responses_tools"]; !ok { + t.Fatalf("expected fallback responses_tools to be populated") + } +} + +func TestProcessMessageDoesNotPersistPartialAssistantToolHistoryOnFailure(t *testing.T) { + cfg := config.DefaultConfig() + cfg.Agents.Defaults.Workspace = t.TempDir() + cfg.Agents.Defaults.MaxToolIterations = 2 + + provider := &sequenceProvider{ + responses: []*providers.LLMResponse{ + { + Content: "", + ToolCalls: []providers.ToolCall{ + {ID: "tool-1", Name: "read_file", Arguments: map[string]interface{}{"path": "missing.txt"}}, + }, + FinishReason: "tool_calls", + }, + }, + errs: []error{nil, errors.New("second pass failed")}, + } + + loop := NewAgentLoop(cfg, bus.NewMessageBus(), provider, nil) + _, err := loop.processMessage(context.Background(), bus.InboundMessage{ + Channel: "cli", + ChatID: "direct", + SenderID: "user", + SessionKey: "cli:direct", + Content: "read file", + }) + if err == nil { + t.Fatalf("expected processMessage error") + } + + history := loop.sessions.GetHistory("cli:direct") + if len(history) != 1 { + t.Fatalf("expected only user message persisted on failure, got %d entries: %#v", len(history), history) + } + if history[0].Role != "user" || history[0].Content != "read file" { + t.Fatalf("unexpected persisted history: %#v", history) + } +} diff --git a/pkg/api/server.go b/pkg/api/server.go index 46aaa1c..fba097d 100644 --- a/pkg/api/server.go +++ b/pkg/api/server.go @@ -25,6 +25,7 @@ import ( "sync" "time" + "github.com/YspCoder/clawgo/pkg/bus" "github.com/YspCoder/clawgo/pkg/channels" cfgpkg "github.com/YspCoder/clawgo/pkg/config" "github.com/YspCoder/clawgo/pkg/providers" @@ -50,6 +51,7 @@ type Server struct { onConfigAfter func(forceRuntimeReload bool) error onCron func(action string, args map[string]interface{}) (interface{}, error) onToolsCatalog func() interface{} + messageBus *bus.MessageBus weixinChannel *channels.WeixinChannel oauthFlowMu sync.Mutex oauthFlows map[string]*providers.OAuthPendingFlow @@ -57,6 +59,15 @@ type Server struct { extraRoutes map[string]http.Handler eventSubsMu sync.Mutex eventSubs map[*websocket.Conn]struct{} + draftMu sync.RWMutex + channelDrafts channelDraftStore +} + +type channelDraftStore struct { + Weixin *cfgpkg.WeixinConfig + Telegram *cfgpkg.TelegramConfig + Feishu *cfgpkg.FeishuConfig + weixinRuntime *channels.WeixinChannel } func NewServer(host string, port int, token string) *Server { @@ -86,6 +97,7 @@ func (s *Server) SetChatHandler(fn func(ctx context.Context, sessionKey, content func (s *Server) SetChatHistoryHandler(fn func(sessionKey string) []map[string]interface{}) { s.onChatHistory = fn } +func (s *Server) SetMessageBus(mb *bus.MessageBus) { s.messageBus = mb } func (s *Server) SetConfigAfterHook(fn func(forceRuntimeReload bool) error) { s.onConfigAfter = fn } func (s *Server) SetCronHandler(fn func(action string, args map[string]interface{}) (interface{}, error)) { s.onCron = fn @@ -117,6 +129,356 @@ func (s *Server) SetWeixinChannel(ch *channels.WeixinChannel) { } } +func cloneWeixinConfig(cfg cfgpkg.WeixinConfig) cfgpkg.WeixinConfig { + cp := cfg + cp.AllowFrom = append([]string(nil), cfg.AllowFrom...) + cp.Accounts = append([]cfgpkg.WeixinAccountConfig(nil), cfg.Accounts...) + return cp +} + +func cloneTelegramConfig(cfg cfgpkg.TelegramConfig) cfgpkg.TelegramConfig { + cp := cfg + cp.AllowFrom = append([]string(nil), cfg.AllowFrom...) + cp.AllowChats = append([]string(nil), cfg.AllowChats...) + return cp +} + +func cloneFeishuConfig(cfg cfgpkg.FeishuConfig) cfgpkg.FeishuConfig { + cp := cfg + cp.AllowFrom = append([]string(nil), cfg.AllowFrom...) + cp.AllowChats = append([]string(nil), cfg.AllowChats...) + return cp +} + +func validChannelDraftName(name string) bool { + switch strings.ToLower(strings.TrimSpace(name)) { + case "weixin", "telegram", "feishu": + return true + default: + return false + } +} + +func decodeMergedJSON[T any](current T, raw json.RawMessage) (T, error) { + out := current + if len(raw) == 0 || string(raw) == "null" { + return out, nil + } + baseBytes, err := json.Marshal(current) + if err != nil { + return out, err + } + merged := map[string]interface{}{} + if err := json.Unmarshal(baseBytes, &merged); err != nil { + return out, err + } + patch := map[string]interface{}{} + if err := json.Unmarshal(raw, &patch); err != nil { + return out, err + } + merged = mergeJSONMap(merged, patch) + mergedBytes, err := json.Marshal(merged) + if err != nil { + return out, err + } + if err := json.Unmarshal(mergedBytes, &out); err != nil { + return out, err + } + return out, nil +} + +func (s *Server) syncWeixinDraftLocked() { + if s.channelDrafts.Weixin == nil || s.channelDrafts.weixinRuntime == nil { + return + } + snapshot := s.channelDrafts.weixinRuntime.SnapshotConfig() + s.channelDrafts.Weixin = &snapshot +} + +func (s *Server) replaceWeixinDraftRuntimeLocked(cfg *cfgpkg.WeixinConfig) error { + if s.channelDrafts.weixinRuntime != nil { + _ = s.channelDrafts.weixinRuntime.Stop(context.Background()) + s.channelDrafts.weixinRuntime = nil + } + if cfg == nil || !cfg.Enabled { + return nil + } + if s.messageBus == nil { + return fmt.Errorf("message bus not configured") + } + ch, err := channels.NewWeixinChannel(cloneWeixinConfig(*cfg), s.messageBus) + if err != nil { + return err + } + if err := ch.Start(context.Background()); err != nil { + return err + } + s.channelDrafts.weixinRuntime = ch + return nil +} + +func (s *Server) clearChannelDraftsLocked() { + if s.channelDrafts.weixinRuntime != nil { + _ = s.channelDrafts.weixinRuntime.Stop(context.Background()) + } + s.channelDrafts = channelDraftStore{} +} + +func (s *Server) clearChannelDrafts() { + s.draftMu.Lock() + defer s.draftMu.Unlock() + s.clearChannelDraftsLocked() +} + +func (s *Server) effectiveWeixinRuntime(persisted cfgpkg.WeixinConfig) (cfgpkg.WeixinConfig, *channels.WeixinChannel, bool) { + s.draftMu.Lock() + defer s.draftMu.Unlock() + if s.channelDrafts.Weixin != nil { + s.syncWeixinDraftLocked() + effective := cloneWeixinConfig(*s.channelDrafts.Weixin) + return effective, s.channelDrafts.weixinRuntime, true + } + return cloneWeixinConfig(persisted), s.weixinChannel, false +} + +func (s *Server) currentChannelDraftPayload(cfg *cfgpkg.Config, channel string) map[string]interface{} { + channel = strings.ToLower(strings.TrimSpace(channel)) + payload := map[string]interface{}{ + "ok": true, + "channel": channel, + } + s.draftMu.Lock() + defer s.draftMu.Unlock() + switch channel { + case "weixin": + persisted := cloneWeixinConfig(cfg.Channels.Weixin) + var draft interface{} + effective := persisted + dirty := s.channelDrafts.Weixin != nil + if dirty { + s.syncWeixinDraftLocked() + effective = cloneWeixinConfig(*s.channelDrafts.Weixin) + draft = effective + } + payload["persisted"] = persisted + payload["draft"] = draft + payload["effective"] = effective + payload["dirty"] = dirty + payload["runtime_enabled"] = s.channelDrafts.weixinRuntime != nil && s.channelDrafts.weixinRuntime.IsRunning() + case "telegram": + persisted := cloneTelegramConfig(cfg.Channels.Telegram) + var draft interface{} + effective := persisted + dirty := s.channelDrafts.Telegram != nil + if dirty { + effective = cloneTelegramConfig(*s.channelDrafts.Telegram) + draft = effective + } + payload["persisted"] = persisted + payload["draft"] = draft + payload["effective"] = effective + payload["dirty"] = dirty + case "feishu": + persisted := cloneFeishuConfig(cfg.Channels.Feishu) + var draft interface{} + effective := persisted + dirty := s.channelDrafts.Feishu != nil + if dirty { + effective = cloneFeishuConfig(*s.channelDrafts.Feishu) + draft = effective + } + payload["persisted"] = persisted + payload["draft"] = draft + payload["effective"] = effective + payload["dirty"] = dirty + } + return payload +} + +func (s *Server) applyChannelDrafts(cfg *cfgpkg.Config) { + if cfg == nil { + return + } + s.draftMu.Lock() + defer s.draftMu.Unlock() + s.syncWeixinDraftLocked() + if s.channelDrafts.Weixin != nil { + cfg.Channels.Weixin = cloneWeixinConfig(*s.channelDrafts.Weixin) + } + if s.channelDrafts.Telegram != nil { + cfg.Channels.Telegram = cloneTelegramConfig(*s.channelDrafts.Telegram) + } + if s.channelDrafts.Feishu != nil { + cfg.Channels.Feishu = cloneFeishuConfig(*s.channelDrafts.Feishu) + } +} + +func (s *Server) handleWebUIChannelDraft(w http.ResponseWriter, r *http.Request) { + if !s.checkAuth(r) { + http.Error(w, "unauthorized", http.StatusUnauthorized) + return + } + cfg, err := s.loadConfig() + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + switch r.Method { + case http.MethodGet: + channel := strings.ToLower(strings.TrimSpace(r.URL.Query().Get("channel"))) + if channel == "" { + writeJSON(w, map[string]interface{}{ + "ok": true, + "channels": map[string]interface{}{ + "weixin": s.currentChannelDraftPayload(cfg, "weixin"), + "telegram": s.currentChannelDraftPayload(cfg, "telegram"), + "feishu": s.currentChannelDraftPayload(cfg, "feishu"), + }, + }) + return + } + if !validChannelDraftName(channel) { + http.Error(w, "unsupported channel", http.StatusBadRequest) + return + } + writeJSON(w, s.currentChannelDraftPayload(cfg, channel)) + case http.MethodPost: + var body struct { + Channel string `json:"channel"` + Config json.RawMessage `json:"config"` + } + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { + http.Error(w, "invalid json", http.StatusBadRequest) + return + } + channel := strings.ToLower(strings.TrimSpace(body.Channel)) + if !validChannelDraftName(channel) { + http.Error(w, "unsupported channel", http.StatusBadRequest) + return + } + s.draftMu.Lock() + switch channel { + case "weixin": + current := cfg.Channels.Weixin + if s.channelDrafts.Weixin != nil { + s.syncWeixinDraftLocked() + current = cloneWeixinConfig(*s.channelDrafts.Weixin) + } + next, err := decodeMergedJSON(current, body.Config) + if err != nil { + s.draftMu.Unlock() + http.Error(w, "invalid weixin config", http.StatusBadRequest) + return + } + next = cloneWeixinConfig(next) + if err := s.replaceWeixinDraftRuntimeLocked(&next); err != nil { + s.draftMu.Unlock() + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + s.channelDrafts.Weixin = &next + case "telegram": + current := cfg.Channels.Telegram + if s.channelDrafts.Telegram != nil { + current = cloneTelegramConfig(*s.channelDrafts.Telegram) + } + next, err := decodeMergedJSON(current, body.Config) + if err != nil { + s.draftMu.Unlock() + http.Error(w, "invalid telegram config", http.StatusBadRequest) + return + } + next = cloneTelegramConfig(next) + s.channelDrafts.Telegram = &next + case "feishu": + current := cfg.Channels.Feishu + if s.channelDrafts.Feishu != nil { + current = cloneFeishuConfig(*s.channelDrafts.Feishu) + } + next, err := decodeMergedJSON(current, body.Config) + if err != nil { + s.draftMu.Unlock() + http.Error(w, "invalid feishu config", http.StatusBadRequest) + return + } + next = cloneFeishuConfig(next) + s.channelDrafts.Feishu = &next + } + s.draftMu.Unlock() + s.broadcastEvent(map[string]interface{}{ + "type": "channel_draft_changed", + "channel": channel, + }) + writeJSON(w, s.currentChannelDraftPayload(cfg, channel)) + case http.MethodDelete: + channel := strings.ToLower(strings.TrimSpace(r.URL.Query().Get("channel"))) + s.draftMu.Lock() + if channel == "" { + s.clearChannelDraftsLocked() + s.draftMu.Unlock() + writeJSON(w, map[string]interface{}{"ok": true, "cleared": "all"}) + return + } + if !validChannelDraftName(channel) { + s.draftMu.Unlock() + http.Error(w, "unsupported channel", http.StatusBadRequest) + return + } + switch channel { + case "weixin": + if s.channelDrafts.weixinRuntime != nil { + _ = s.channelDrafts.weixinRuntime.Stop(context.Background()) + s.channelDrafts.weixinRuntime = nil + } + s.channelDrafts.Weixin = nil + case "telegram": + s.channelDrafts.Telegram = nil + case "feishu": + s.channelDrafts.Feishu = nil + } + s.draftMu.Unlock() + s.broadcastEvent(map[string]interface{}{ + "type": "channel_draft_changed", + "channel": channel, + }) + writeJSON(w, map[string]interface{}{"ok": true, "channel": channel, "cleared": true}) + default: + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + } +} + +func (s *Server) handleWebUIChannelDraftCommit(w http.ResponseWriter, r *http.Request) { + if !s.checkAuth(r) { + http.Error(w, "unauthorized", http.StatusUnauthorized) + return + } + if r.Method != http.MethodPost { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + cfg, err := s.loadConfig() + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + s.applyChannelDrafts(cfg) + if err := s.persistWebUIConfig(cfg); err != nil { + var validationErr *configValidationError + if errors.As(err, &validationErr) { + writeJSONStatus(w, http.StatusBadRequest, map[string]interface{}{ + "ok": false, + "error": validationErr.Error(), + "errors": validationErr.Fields, + }) + return + } + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + s.clearChannelDrafts() + writeJSON(w, map[string]interface{}{"ok": true, "committed": true}) +} + func (s *Server) handleWebUIEventsLive(w http.ResponseWriter, r *http.Request) { if !s.checkAuth(r) { http.Error(w, "unauthorized", http.StatusUnauthorized) @@ -225,11 +587,14 @@ func (s *Server) Start(ctx context.Context) error { mux.HandleFunc("/api/sessions", s.handleWebUISessions) mux.HandleFunc("/api/memory", s.handleWebUIMemory) mux.HandleFunc("/api/workspace_file", s.handleWebUIWorkspaceFile) + mux.HandleFunc("/api/workspace_docs", s.handleWebUIWorkspaceDocs) mux.HandleFunc("/api/tool_allowlist_groups", s.handleWebUIToolAllowlistGroups) mux.HandleFunc("/api/tools", s.handleWebUITools) mux.HandleFunc("/api/mcp/install", s.handleWebUIMCPInstall) mux.HandleFunc("/api/logs/live", s.handleWebUILogsLive) mux.HandleFunc("/api/logs/recent", s.handleWebUILogsRecent) + mux.HandleFunc("/api/channels/draft", s.handleWebUIChannelDraft) + mux.HandleFunc("/api/channels/draft/commit", s.handleWebUIChannelDraftCommit) s.extraRoutesMu.RLock() for path, handler := range s.extraRoutes { routePath := path @@ -1227,11 +1592,17 @@ func (s *Server) handleWebUIWeixinLoginStart(w http.ResponseWriter, r *http.Requ http.Error(w, "method not allowed", http.StatusMethodNotAllowed) return } - if s.weixinChannel == nil { + cfg, err := s.loadConfig() + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + _, ch, _ := s.effectiveWeixinRuntime(cfg.Channels.Weixin) + if ch == nil { http.Error(w, "weixin channel unavailable", http.StatusServiceUnavailable) return } - if _, err := s.weixinChannel.StartLogin(r.Context()); err != nil { + if _, err := ch.StartLogin(r.Context()); err != nil { http.Error(w, err.Error(), http.StatusBadGateway) return } @@ -1248,7 +1619,13 @@ func (s *Server) handleWebUIWeixinLoginCancel(w http.ResponseWriter, r *http.Req http.Error(w, "method not allowed", http.StatusMethodNotAllowed) return } - if s.weixinChannel == nil { + cfg, err := s.loadConfig() + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + _, ch, _ := s.effectiveWeixinRuntime(cfg.Channels.Weixin) + if ch == nil { http.Error(w, "weixin channel unavailable", http.StatusServiceUnavailable) return } @@ -1259,7 +1636,7 @@ func (s *Server) handleWebUIWeixinLoginCancel(w http.ResponseWriter, r *http.Req http.Error(w, "invalid json body", http.StatusBadRequest) return } - if !s.weixinChannel.CancelPendingLogin(body.LoginID) { + if !ch.CancelPendingLogin(body.LoginID) { http.Error(w, "login_id not found", http.StatusNotFound) return } @@ -1283,8 +1660,14 @@ func (s *Server) handleWebUIWeixinQR(w http.ResponseWriter, r *http.Request) { } qrCode := "" loginID := strings.TrimSpace(r.URL.Query().Get("login_id")) - if loginID != "" && s.weixinChannel != nil { - if pending := s.weixinChannel.PendingLoginByID(loginID); pending != nil { + cfg, err := s.loadConfig() + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + _, ch, _ := s.effectiveWeixinRuntime(cfg.Channels.Weixin) + if loginID != "" && ch != nil { + if pending := ch.PendingLoginByID(loginID); pending != nil { qrCode = fallbackString(pending.QRCodeImgContent, pending.QRCode) } } @@ -1318,7 +1701,13 @@ func (s *Server) handleWebUIWeixinAccountRemove(w http.ResponseWriter, r *http.R http.Error(w, "method not allowed", http.StatusMethodNotAllowed) return } - if s.weixinChannel == nil { + cfg, err := s.loadConfig() + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + _, ch, _ := s.effectiveWeixinRuntime(cfg.Channels.Weixin) + if ch == nil { http.Error(w, "weixin channel unavailable", http.StatusServiceUnavailable) return } @@ -1329,7 +1718,7 @@ func (s *Server) handleWebUIWeixinAccountRemove(w http.ResponseWriter, r *http.R http.Error(w, "invalid json body", http.StatusBadRequest) return } - if err := s.weixinChannel.RemoveAccount(body.BotID); err != nil { + if err := ch.RemoveAccount(body.BotID); err != nil { http.Error(w, err.Error(), http.StatusBadRequest) return } @@ -1346,7 +1735,13 @@ func (s *Server) handleWebUIWeixinAccountDefault(w http.ResponseWriter, r *http. http.Error(w, "method not allowed", http.StatusMethodNotAllowed) return } - if s.weixinChannel == nil { + cfg, err := s.loadConfig() + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + _, ch, _ := s.effectiveWeixinRuntime(cfg.Channels.Weixin) + if ch == nil { http.Error(w, "weixin channel unavailable", http.StatusServiceUnavailable) return } @@ -1357,7 +1752,7 @@ func (s *Server) handleWebUIWeixinAccountDefault(w http.ResponseWriter, r *http. http.Error(w, "invalid json body", http.StatusBadRequest) return } - if err := s.weixinChannel.SetDefaultAccount(body.BotID); err != nil { + if err := ch.SetDefaultAccount(body.BotID); err != nil { http.Error(w, err.Error(), http.StatusBadRequest) return } @@ -1373,25 +1768,35 @@ func (s *Server) webUIWeixinStatusPayload(ctx context.Context) (map[string]inter "error": err.Error(), }, http.StatusInternalServerError } - weixinCfg := cfg.Channels.Weixin - if s.weixinChannel == nil { + persistedCfg := cloneWeixinConfig(cfg.Channels.Weixin) + weixinCfg, ch, usingDraft := s.effectiveWeixinRuntime(persistedCfg) + if ch == nil { return map[string]interface{}{ - "ok": false, - "enabled": weixinCfg.Enabled, - "base_url": weixinCfg.BaseURL, - "error": "weixin channel unavailable", + "ok": false, + "enabled": weixinCfg.Enabled, + "config_enabled": persistedCfg.Enabled, + "runtime_enabled": false, + "draft_dirty": usingDraft, + "base_url": weixinCfg.BaseURL, + "error": "weixin channel unavailable", }, http.StatusOK } - pendingLogins, err := s.weixinChannel.RefreshLoginStatuses(ctx) + pendingLogins, err := ch.RefreshLoginStatuses(ctx) if err != nil { return map[string]interface{}{ - "ok": false, - "enabled": weixinCfg.Enabled, - "base_url": weixinCfg.BaseURL, - "error": err.Error(), + "ok": false, + "enabled": weixinCfg.Enabled, + "config_enabled": persistedCfg.Enabled, + "runtime_enabled": ch.IsRunning(), + "draft_dirty": usingDraft, + "base_url": weixinCfg.BaseURL, + "error": err.Error(), }, http.StatusOK } - accounts := s.weixinChannel.ListAccounts() + if usingDraft { + weixinCfg = ch.SnapshotConfig() + } + accounts := ch.ListAccounts() pendingPayload := make([]map[string]interface{}, 0, len(pendingLogins)) for _, pending := range pendingLogins { pendingPayload = append(pendingPayload, map[string]interface{}{ @@ -1409,10 +1814,13 @@ func (s *Server) webUIWeixinStatusPayload(ctx context.Context) (map[string]inter firstPending = pendingLogins[0] } return map[string]interface{}{ - "ok": true, - "enabled": weixinCfg.Enabled, - "base_url": fallbackString(weixinCfg.BaseURL, "https://ilinkai.weixin.qq.com"), - "pending_logins": pendingPayload, + "ok": true, + "enabled": weixinCfg.Enabled, + "config_enabled": persistedCfg.Enabled, + "runtime_enabled": ch.IsRunning(), + "draft_dirty": usingDraft, + "base_url": fallbackString(weixinCfg.BaseURL, "https://ilinkai.weixin.qq.com"), + "pending_logins": pendingPayload, "pending_login": map[string]interface{}{ "login_id": pendingString(firstPending, "login_id"), "qr_code": pendingString(firstPending, "qr_code"), @@ -2976,9 +3384,6 @@ func (s *Server) handleWebUIMemory(w http.ResponseWriter, r *http.Request) { path := strings.TrimSpace(r.URL.Query().Get("path")) if path == "" { files := make([]string, 0, 16) - if _, err := os.Stat(filepath.Join(s.workspacePath, "MEMORY.md")); err == nil { - files = append(files, "MEMORY.md") - } entries, err := os.ReadDir(memoryDir) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) @@ -2993,11 +3398,7 @@ func (s *Server) handleWebUIMemory(w http.ResponseWriter, r *http.Request) { writeJSON(w, map[string]interface{}{"ok": true, "files": files}) return } - baseDir := memoryDir - if strings.EqualFold(path, "MEMORY.md") { - baseDir = strings.TrimSpace(s.workspacePath) - } - clean, content, found, err := readRelativeTextFile(baseDir, path) + clean, content, found, err := readRelativeTextFile(memoryDir, path) if err != nil { http.Error(w, err.Error(), http.StatusBadRequest) return @@ -3073,6 +3474,70 @@ func (s *Server) handleWebUIWorkspaceFile(w http.ResponseWriter, r *http.Request } } +var workspaceDocFiles = []string{ + "AGENTS.md", + "BOOT.md", + "BOOTSTRAP.md", + "HEARTBEAT.md", + "IDENTITY.md", + "MEMORY.md", + "SOUL.md", + "TOOLS.md", + "USER.md", +} + +func (s *Server) handleWebUIWorkspaceDocs(w http.ResponseWriter, r *http.Request) { + if !s.checkAuth(r) { + http.Error(w, "unauthorized", http.StatusUnauthorized) + return + } + if r.Method != http.MethodGet { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + workspace := strings.TrimSpace(s.workspacePath) + path := strings.TrimSpace(r.URL.Query().Get("path")) + if path != "" { + if !isWorkspaceDocAllowed(path) { + http.Error(w, "invalid path", http.StatusBadRequest) + return + } + clean, content, found, err := readRelativeTextFile(workspace, path) + if err != nil { + http.Error(w, err.Error(), relativeFilePathStatus(err)) + return + } + if !found { + http.Error(w, os.ErrNotExist.Error(), http.StatusInternalServerError) + return + } + writeJSON(w, map[string]interface{}{"ok": true, "path": clean, "content": content}) + return + } + files := make([]string, 0, len(workspaceDocFiles)) + for _, name := range workspaceDocFiles { + _, _, found, err := readRelativeTextFile(workspace, name) + if err != nil { + http.Error(w, err.Error(), relativeFilePathStatus(err)) + return + } + if !found { + continue + } + files = append(files, name) + } + writeJSON(w, map[string]interface{}{"ok": true, "files": files}) +} + +func isWorkspaceDocAllowed(name string) bool { + for _, allowed := range workspaceDocFiles { + if name == allowed { + return true + } + } + return false +} + func (s *Server) handleWebUILogsRecent(w http.ResponseWriter, r *http.Request) { if !s.checkAuth(r) { http.Error(w, "unauthorized", http.StatusUnauthorized) diff --git a/pkg/api/server_test.go b/pkg/api/server_test.go index e8191e0..ec7642e 100644 --- a/pkg/api/server_test.go +++ b/pkg/api/server_test.go @@ -12,6 +12,7 @@ import ( "testing" "time" + "github.com/YspCoder/clawgo/pkg/bus" cfgpkg "github.com/YspCoder/clawgo/pkg/config" "github.com/gorilla/websocket" ) @@ -139,6 +140,100 @@ func TestHandleWebUIConfigPostSavesNormalizedConfig(t *testing.T) { } } +func TestHandleWebUIChannelDraftCommitPersistsDrafts(t *testing.T) { + t.Parallel() + + tmp := t.TempDir() + cfgPath := filepath.Join(tmp, "config.json") + cfg := cfgpkg.DefaultConfig() + if err := cfgpkg.SaveConfig(cfgPath, cfg); err != nil { + t.Fatalf("save config: %v", err) + } + + srv := NewServer("127.0.0.1", 0, "") + srv.SetConfigPath(cfgPath) + srv.SetMessageBus(bus.NewMessageBus()) + hookCalled := 0 + srv.SetConfigAfterHook(func(forceRuntimeReload bool) error { + hookCalled++ + return nil + }) + + draftReq := httptest.NewRequest(http.MethodPost, "/api/channels/draft", strings.NewReader(`{"channel":"telegram","config":{"enabled":true,"token":"bot-token","streaming":true}}`)) + draftReq.Header.Set("Content-Type", "application/json") + draftRec := httptest.NewRecorder() + srv.handleWebUIChannelDraft(draftRec, draftReq) + if draftRec.Code != http.StatusOK { + t.Fatalf("expected 200 from draft save, got %d: %s", draftRec.Code, draftRec.Body.String()) + } + + commitReq := httptest.NewRequest(http.MethodPost, "/api/channels/draft/commit", nil) + commitRec := httptest.NewRecorder() + srv.handleWebUIChannelDraftCommit(commitRec, commitReq) + if commitRec.Code != http.StatusOK { + t.Fatalf("expected 200 from draft commit, got %d: %s", commitRec.Code, commitRec.Body.String()) + } + if hookCalled != 1 { + t.Fatalf("expected reload hook once, got %d", hookCalled) + } + + updated, err := cfgpkg.LoadConfig(cfgPath) + if err != nil { + t.Fatalf("reload config: %v", err) + } + if !updated.Channels.Telegram.Enabled { + t.Fatalf("expected telegram enabled after draft commit") + } + if updated.Channels.Telegram.Token != "bot-token" { + t.Fatalf("expected telegram token to persist, got %q", updated.Channels.Telegram.Token) + } +} + +func TestHandleWebUIWeixinStatusReflectsDraftRuntime(t *testing.T) { + t.Parallel() + + tmp := t.TempDir() + cfgPath := filepath.Join(tmp, "config.json") + cfg := cfgpkg.DefaultConfig() + cfg.Channels.Weixin.Enabled = false + if err := cfgpkg.SaveConfig(cfgPath, cfg); err != nil { + t.Fatalf("save config: %v", err) + } + + srv := NewServer("127.0.0.1", 0, "") + srv.SetConfigPath(cfgPath) + srv.SetMessageBus(bus.NewMessageBus()) + + draftReq := httptest.NewRequest(http.MethodPost, "/api/channels/draft", strings.NewReader(`{"channel":"weixin","config":{"enabled":true,"base_url":"https://ilinkai.weixin.qq.com"}}`)) + draftReq.Header.Set("Content-Type", "application/json") + draftRec := httptest.NewRecorder() + srv.handleWebUIChannelDraft(draftRec, draftReq) + if draftRec.Code != http.StatusOK { + t.Fatalf("expected 200 from weixin draft save, got %d: %s", draftRec.Code, draftRec.Body.String()) + } + + statusReq := httptest.NewRequest(http.MethodGet, "/api/weixin/status", nil) + statusRec := httptest.NewRecorder() + srv.handleWebUIWeixinStatus(statusRec, statusReq) + if statusRec.Code != http.StatusOK { + t.Fatalf("expected 200 from status, got %d: %s", statusRec.Code, statusRec.Body.String()) + } + + var payload map[string]interface{} + if err := json.Unmarshal(statusRec.Body.Bytes(), &payload); err != nil { + t.Fatalf("decode status: %v", err) + } + if payload["draft_dirty"] != true { + t.Fatalf("expected draft_dirty=true, got %#v", payload["draft_dirty"]) + } + if payload["config_enabled"] != false { + t.Fatalf("expected config_enabled=false, got %#v", payload["config_enabled"]) + } + if payload["runtime_enabled"] != true { + t.Fatalf("expected runtime_enabled=true, got %#v", payload["runtime_enabled"]) + } +} + func TestWithCORSEchoesPreflightHeaders(t *testing.T) { t.Parallel() @@ -255,13 +350,10 @@ func TestSaveProviderConfigForcesRuntimeReload(t *testing.T) { } } -func TestHandleWebUIMemoryListsAndReadsWorkspaceMemoryFile(t *testing.T) { +func TestHandleWebUIMemoryListsAndReadsMemoryDirFile(t *testing.T) { t.Parallel() tmp := t.TempDir() - if err := os.WriteFile(filepath.Join(tmp, "MEMORY.md"), []byte("# long-term\n"), 0o644); err != nil { - t.Fatalf("write workspace memory: %v", err) - } if err := os.MkdirAll(filepath.Join(tmp, "memory"), 0o755); err != nil { t.Fatalf("mkdir memory dir: %v", err) } @@ -285,11 +377,11 @@ func TestHandleWebUIMemoryListsAndReadsWorkspaceMemoryFile(t *testing.T) { if err := json.Unmarshal(listRec.Body.Bytes(), &listPayload); err != nil { t.Fatalf("decode list payload: %v", err) } - if len(listPayload.Files) < 2 || listPayload.Files[0] != "MEMORY.md" { - t.Fatalf("expected MEMORY.md in memory file list, got %+v", listPayload.Files) + if len(listPayload.Files) != 1 || listPayload.Files[0] != "2026-03-19.md" { + t.Fatalf("expected only memory dir files, got %+v", listPayload.Files) } - readReq := httptest.NewRequest(http.MethodGet, "/api/memory?path=MEMORY.md", nil) + readReq := httptest.NewRequest(http.MethodGet, "/api/memory?path=2026-03-19.md", nil) readRec := httptest.NewRecorder() srv.handleWebUIMemory(readRec, readReq) if readRec.Code != http.StatusOK { @@ -303,11 +395,71 @@ func TestHandleWebUIMemoryListsAndReadsWorkspaceMemoryFile(t *testing.T) { if err := json.Unmarshal(readRec.Body.Bytes(), &readPayload); err != nil { t.Fatalf("decode read payload: %v", err) } - if readPayload.Path != "MEMORY.md" || readPayload.Content != "# long-term\n" { + if readPayload.Path != "2026-03-19.md" || readPayload.Content != "daily\n" { t.Fatalf("unexpected memory payload: %+v", readPayload) } } +func TestHandleWebUIWorkspaceDocsListAndRead(t *testing.T) { + t.Parallel() + + tmp := t.TempDir() + if err := os.WriteFile(filepath.Join(tmp, "AGENTS.md"), []byte("agents\n"), 0o644); err != nil { + t.Fatalf("write AGENTS.md: %v", err) + } + if err := os.WriteFile(filepath.Join(tmp, "MEMORY.md"), []byte("memory\n"), 0o644); err != nil { + t.Fatalf("write MEMORY.md: %v", err) + } + + srv := NewServer("127.0.0.1", 0, "") + srv.SetWorkspacePath(tmp) + + req := httptest.NewRequest(http.MethodGet, "/api/workspace_docs", nil) + rec := httptest.NewRecorder() + srv.handleWebUIWorkspaceDocs(rec, req) + if rec.Code != http.StatusOK { + t.Fatalf("expected 200, got %d: %s", rec.Code, rec.Body.String()) + } + + var payload struct { + OK bool `json:"ok"` + Files []string `json:"files"` + } + if err := json.Unmarshal(rec.Body.Bytes(), &payload); err != nil { + t.Fatalf("decode payload: %v", err) + } + if !payload.OK { + t.Fatalf("expected ok=true, got %+v", payload) + } + if len(payload.Files) != 2 { + t.Fatalf("expected 2 existing docs, got %+v", payload.Files) + } + if payload.Files[0] != "AGENTS.md" { + t.Fatalf("unexpected first doc payload: %+v", payload.Files[0]) + } + if payload.Files[1] != "MEMORY.md" { + t.Fatalf("unexpected second doc payload: %+v", payload.Files[1]) + } + + readReq := httptest.NewRequest(http.MethodGet, "/api/workspace_docs?path=AGENTS.md", nil) + readRec := httptest.NewRecorder() + srv.handleWebUIWorkspaceDocs(readRec, readReq) + if readRec.Code != http.StatusOK { + t.Fatalf("expected 200, got %d: %s", readRec.Code, readRec.Body.String()) + } + var readPayload struct { + OK bool `json:"ok"` + Path string `json:"path"` + Content string `json:"content"` + } + if err := json.Unmarshal(readRec.Body.Bytes(), &readPayload); err != nil { + t.Fatalf("decode read payload: %v", err) + } + if readPayload.Path != "AGENTS.md" || readPayload.Content != "agents\n" { + t.Fatalf("unexpected read payload: %+v", readPayload) + } +} + func TestHandleWebUIChatLive(t *testing.T) { t.Parallel() diff --git a/pkg/channels/weixin.go b/pkg/channels/weixin.go index c64aa90..0fba49b 100644 --- a/pkg/channels/weixin.go +++ b/pkg/channels/weixin.go @@ -389,6 +389,22 @@ func (c *WeixinChannel) ListAccounts() []WeixinAccountSnapshot { return out } +func (c *WeixinChannel) SnapshotConfig() config.WeixinConfig { + c.mu.RLock() + defer c.mu.RUnlock() + + cfgCopy := c.config + cfgCopy.AllowFrom = append([]string(nil), cfgCopy.AllowFrom...) + cfgCopy.Accounts = append([]config.WeixinAccountConfig(nil), c.accountConfigsLocked()...) + cfgCopy.DefaultBotID = strings.TrimSpace(c.defaultBotIDLocked()) + cfgCopy.BotID = "" + cfgCopy.BotToken = "" + cfgCopy.IlinkUserID = "" + cfgCopy.ContextToken = "" + cfgCopy.GetUpdatesBuf = "" + return cfgCopy +} + func (c *WeixinChannel) SetDefaultAccount(botID string) error { botID = strings.TrimSpace(botID) if botID == "" { diff --git a/pkg/config/config.go b/pkg/config/config.go index 7267ddb..cd729af 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -226,6 +226,8 @@ type ProviderConfig struct { APIKey string `json:"api_key" env:"CLAWGO_PROVIDERS_{{.Name}}_API_KEY"` APIBase string `json:"api_base" env:"CLAWGO_PROVIDERS_{{.Name}}_API_BASE"` Models []string `json:"models" env:"CLAWGO_PROVIDERS_{{.Name}}_MODELS"` + MaxTokens int `json:"max_tokens,omitempty"` + Temperature float64 `json:"temperature,omitempty"` SupportsResponsesCompact bool `json:"supports_responses_compact" env:"CLAWGO_PROVIDERS_{{.Name}}_SUPPORTS_RESPONSES_COMPACT"` Auth string `json:"auth" env:"CLAWGO_PROVIDERS_{{.Name}}_AUTH"` TimeoutSec int `json:"timeout_sec" env:"CLAWGO_PROVIDERS_PROXY_TIMEOUT_SEC"` diff --git a/pkg/config/normalized.go b/pkg/config/normalized.go index 9871422..20fe9b3 100644 --- a/pkg/config/normalized.go +++ b/pkg/config/normalized.go @@ -53,6 +53,8 @@ type NormalizedRuntimeRouterConfig struct { type NormalizedRuntimeProviderConfig struct { Auth string `json:"auth,omitempty"` APIBase string `json:"api_base,omitempty"` + MaxTokens int `json:"max_tokens,omitempty"` + Temperature float64 `json:"temperature,omitempty"` TimeoutSec int `json:"timeout_sec,omitempty"` OAuth ProviderOAuthConfig `json:"oauth,omitempty"` RuntimePersist bool `json:"runtime_persist,omitempty"` @@ -143,6 +145,8 @@ func (c *Config) NormalizedView() NormalizedConfig { view.Runtime.Providers[name] = NormalizedRuntimeProviderConfig{ Auth: pc.Auth, APIBase: pc.APIBase, + MaxTokens: pc.MaxTokens, + Temperature: pc.Temperature, TimeoutSec: pc.TimeoutSec, OAuth: pc.OAuth, RuntimePersist: pc.RuntimePersist, @@ -232,6 +236,12 @@ func (c *Config) ApplyNormalizedView(view NormalizedConfig) { current := c.Models.Providers[name] current.Auth = strings.TrimSpace(item.Auth) current.APIBase = strings.TrimSpace(item.APIBase) + if item.MaxTokens > 0 { + current.MaxTokens = item.MaxTokens + } else if item.MaxTokens == 0 { + current.MaxTokens = 0 + } + current.Temperature = item.Temperature if item.TimeoutSec > 0 { current.TimeoutSec = item.TimeoutSec } diff --git a/pkg/config/normalized_test.go b/pkg/config/normalized_test.go index 31e26a7..520e08b 100644 --- a/pkg/config/normalized_test.go +++ b/pkg/config/normalized_test.go @@ -5,6 +5,13 @@ import "testing" func TestNormalizedViewProjectsCoreAndRuntime(t *testing.T) { cfg := DefaultConfig() cfg.Agents.Router.Enabled = true + cfg.Models.Providers["openai"] = ProviderConfig{ + APIBase: "https://api.openai.com/v1", + Models: []string{"gpt-5.4"}, + MaxTokens: 12288, + Temperature: 0.35, + TimeoutSec: 90, + } cfg.Agents.Subagents["coder"] = SubagentConfig{ Enabled: true, Role: "coding", @@ -27,4 +34,10 @@ func TestNormalizedViewProjectsCoreAndRuntime(t *testing.T) { if !view.Runtime.Router.Enabled || view.Runtime.Router.Strategy != "rules_first" { t.Fatalf("unexpected runtime router: %+v", view.Runtime.Router) } + if got := view.Runtime.Providers["openai"].MaxTokens; got != 12288 { + t.Fatalf("expected provider max_tokens in normalized runtime view, got %d", got) + } + if got := view.Runtime.Providers["openai"].Temperature; got != 0.35 { + t.Fatalf("expected provider temperature in normalized runtime view, got %v", got) + } } diff --git a/pkg/config/validate.go b/pkg/config/validate.go index 529824b..1963e8d 100644 --- a/pkg/config/validate.go +++ b/pkg/config/validate.go @@ -478,6 +478,9 @@ func validateProviderConfig(path string, p ProviderConfig) []error { if p.TimeoutSec <= 0 { errs = append(errs, fmt.Errorf("%s.timeout_sec must be > 0", path)) } + if p.MaxTokens < 0 { + errs = append(errs, fmt.Errorf("%s.max_tokens must be >= 0", path)) + } switch authMode { case "", "bearer", "oauth", "none", "hybrid": default: diff --git a/pkg/providers/codex_provider.go b/pkg/providers/codex_provider.go index 054c1ef..2a8c4e5 100644 --- a/pkg/providers/codex_provider.go +++ b/pkg/providers/codex_provider.go @@ -495,9 +495,7 @@ func (p *CodexProvider) doStreamAttempt(req *http.Request, attempt authAttempt, if typ := strings.TrimSpace(fmt.Sprintf("%v", obj["type"])); typ == "response.completed" { completed = true if respObj, ok := obj["response"]; ok { - if b, err := json.Marshal(respObj); err == nil { - finalJSON = b - } + finalJSON = mergeStreamFinalJSON(finalJSON, respObj) } } } diff --git a/pkg/providers/codex_provider_test.go b/pkg/providers/codex_provider_test.go index 306c591..43e301d 100644 --- a/pkg/providers/codex_provider_test.go +++ b/pkg/providers/codex_provider_test.go @@ -214,6 +214,27 @@ func TestCodexProviderChatFallsBackToHTTPStreamResponse(t *testing.T) { } } +func TestCodexProviderChatMergesLateUsageFromStreamingCompletion(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + _, _ = fmt.Fprint(w, "data: {\"type\":\"response.completed\",\"response\":{\"status\":\"completed\",\"output_text\":\"hello\"}}\n\n") + _, _ = fmt.Fprint(w, "data: {\"type\":\"response.completed\",\"response\":{\"usage\":{\"input_tokens\":1,\"output_tokens\":2,\"total_tokens\":3}}}\n\n") + })) + defer server.Close() + + provider := NewCodexProvider("codex", "test-api-key", server.URL, "gpt-5.4", false, "", 5*time.Second, nil) + resp, err := provider.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "gpt-5.4", nil) + if err != nil { + t.Fatalf("Chat error: %v", err) + } + if resp.Content != "hello" { + t.Fatalf("unexpected response content: %q", resp.Content) + } + if resp.Usage == nil || resp.Usage.PromptTokens != 1 || resp.Usage.CompletionTokens != 2 || resp.Usage.TotalTokens != 3 { + t.Fatalf("unexpected usage: %#v", resp.Usage) + } +} + func TestCodexHandleAttemptFailureMarksAPIKeyCooldown(t *testing.T) { provider := NewCodexProvider("codex-websocket-failure", "test-api-key", "", "gpt-5.4", false, "", 5*time.Second, nil) provider.handleAttemptFailure(authAttempt{kind: "api_key", token: "test-api-key"}, http.StatusTooManyRequests, []byte(`{"error":{"message":"rate limit exceeded"}}`)) diff --git a/pkg/providers/http_provider.go b/pkg/providers/http_provider.go index e8620b5..7da55ff 100644 --- a/pkg/providers/http_provider.go +++ b/pkg/providers/http_provider.go @@ -1022,15 +1022,13 @@ func (p *HTTPProvider) doStreamAttempt(req *http.Request, attempt authAttempt, o if err := json.Unmarshal([]byte(payload), &obj); err == nil { if typ := strings.TrimSpace(fmt.Sprintf("%v", obj["type"])); typ == "response.completed" { if respObj, ok := obj["response"]; ok { - if b, err := json.Marshal(respObj); err == nil { - finalJSON = b - } + finalJSON = mergeStreamFinalJSON(finalJSON, respObj) } } if choices, ok := obj["choices"]; ok { - if b, err := json.Marshal(map[string]interface{}{"choices": choices, "usage": obj["usage"]}); err == nil { - finalJSON = b - } + finalJSON = mergeStreamFinalJSON(finalJSON, map[string]interface{}{"choices": choices, "usage": obj["usage"]}) + } else if _, ok := obj["usage"]; ok && len(finalJSON) > 0 { + finalJSON = mergeStreamFinalJSON(finalJSON, map[string]interface{}{"usage": obj["usage"]}) } } } @@ -1049,6 +1047,56 @@ func (p *HTTPProvider) doStreamAttempt(req *http.Request, attempt authAttempt, o return finalJSON, resp.StatusCode, ctype, false, nil } +func mergeStreamFinalJSON(existing []byte, incoming interface{}) []byte { + if incoming == nil { + return existing + } + incomingMap, ok := incoming.(map[string]interface{}) + if !ok { + data, err := json.Marshal(incoming) + if err != nil { + return existing + } + return data + } + if len(existing) == 0 { + data, err := json.Marshal(incomingMap) + if err != nil { + return existing + } + return data + } + var merged map[string]interface{} + if err := json.Unmarshal(existing, &merged); err != nil || merged == nil { + merged = map[string]interface{}{} + } + merged = mergeStringAnyMaps(merged, incomingMap) + data, err := json.Marshal(merged) + if err != nil { + return existing + } + return data +} + +func mergeStringAnyMaps(dst, src map[string]interface{}) map[string]interface{} { + if dst == nil { + dst = map[string]interface{}{} + } + for key, value := range src { + if value == nil { + continue + } + if nestedSrc, ok := value.(map[string]interface{}); ok { + if nestedDst, ok := dst[key].(map[string]interface{}); ok { + dst[key] = mergeStringAnyMaps(nestedDst, nestedSrc) + continue + } + } + dst[key] = value + } + return dst +} + func shouldRetryOAuthQuota(status int, body []byte) bool { _, retry := classifyOAuthFailure(status, body) return retry diff --git a/pkg/providers/oauth_test.go b/pkg/providers/oauth_test.go index 78a6ca3..2856674 100644 --- a/pkg/providers/oauth_test.go +++ b/pkg/providers/oauth_test.go @@ -197,6 +197,44 @@ func TestHTTPProviderOAuthSwitchesAccountOnQuota(t *testing.T) { } } +func TestHTTPProviderOpenAICompatStreamMergesLateUsage(t *testing.T) { + t.Parallel() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/v1/chat/completions" { + http.NotFound(w, r) + return + } + w.Header().Set("Content-Type", "text/event-stream") + _, _ = w.Write([]byte("data: {\"choices\":[{\"index\":0,\"message\":{\"content\":\"hello\"},\"finish_reason\":\"stop\"}]}\n\n")) + _, _ = w.Write([]byte("data: {\"usage\":{\"prompt_tokens\":1,\"completion_tokens\":2,\"total_tokens\":3}}\n\n")) + })) + defer server.Close() + + provider := NewHTTPProvider("openai", "token", server.URL+"/v1", "gpt-test", false, "api_key", 5*time.Second, nil) + req, err := http.NewRequestWithContext(context.Background(), http.MethodPost, server.URL+"/v1/chat/completions", nil) + if err != nil { + t.Fatalf("new request failed: %v", err) + } + body, status, _, _, err := provider.doStreamAttempt(req, authAttempt{kind: "api_key", token: "token"}, nil) + if err != nil { + t.Fatalf("stream attempt failed: %v", err) + } + if status != http.StatusOK { + t.Fatalf("unexpected status: %d", status) + } + resp, err := parseOpenAICompatResponse(body) + if err != nil { + t.Fatalf("parse response failed: %v", err) + } + if resp.Content != "hello" { + t.Fatalf("unexpected response content: %q", resp.Content) + } + if resp.Usage == nil || resp.Usage.PromptTokens != 1 || resp.Usage.CompletionTokens != 2 || resp.Usage.TotalTokens != 3 { + t.Fatalf("unexpected usage: %#v", resp.Usage) + } +} + func TestOAuthManagerPreRefreshesExpiringSession(t *testing.T) { t.Parallel()