diff --git a/cmd/cmd_gateway.go b/cmd/cmd_gateway.go
index 5afda3a..d2adee4 100644
--- a/cmd/cmd_gateway.go
+++ b/cmd/cmd_gateway.go
@@ -275,6 +275,7 @@ func gatewayCmd() {
registryServer.SetConfigAfterHook(func(forceRuntimeReload bool) error {
return triggerReload("api", forceRuntimeReload)
})
+ registryServer.SetMessageBus(msgBus)
if rawWeixin, ok := channelManager.GetChannel("weixin"); ok {
if weixinChannel, ok := rawWeixin.(*channels.WeixinChannel); ok {
weixinChannel.SetConfigPath(getConfigPath())
diff --git a/config.example.json b/config.example.json
index 018b750..4eda1b1 100644
--- a/config.example.json
+++ b/config.example.json
@@ -178,6 +178,8 @@
"codex": {
"api_base": "https://api.openai.com/v1",
"models": ["gpt-5.4"],
+ "max_tokens": 8192,
+ "temperature": 0.7,
"responses": {
"web_search_enabled": false,
"web_search_context_size": "",
@@ -204,6 +206,8 @@
"gemini": {
"api_base": "https://generativelanguage.googleapis.com/v1beta/openai",
"models": ["gemini-2.5-pro"],
+ "max_tokens": 8192,
+ "temperature": 0.7,
"responses": {
"web_search_enabled": false,
"web_search_context_size": "",
@@ -231,6 +235,8 @@
"api_key": "sk-your-openai-api-key",
"api_base": "https://api.openai.com/v1",
"models": ["gpt-5.4", "gpt-5.4-mini"],
+ "max_tokens": 8192,
+ "temperature": 0.7,
"responses": {
"web_search_enabled": false,
"web_search_context_size": "",
@@ -245,6 +251,8 @@
"anthropic": {
"api_base": "https://api.anthropic.com",
"models": ["claude-sonnet-4-20250514"],
+ "max_tokens": 8192,
+ "temperature": 0.7,
"responses": {
"web_search_enabled": false,
"web_search_context_size": "",
@@ -270,6 +278,8 @@
"qwen": {
"api_base": "https://dashscope.aliyuncs.com/compatible-mode/v1",
"models": ["qwen-max"],
+ "max_tokens": 8192,
+ "temperature": 0.7,
"responses": {
"web_search_enabled": false,
"web_search_context_size": "",
@@ -294,6 +304,8 @@
"kimi": {
"api_base": "https://api.moonshot.cn/v1",
"models": ["kimi-k2-0711-preview"],
+ "max_tokens": 8192,
+ "temperature": 0.7,
"responses": {
"web_search_enabled": false,
"web_search_context_size": "",
diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go
index f8113af..7c12080 100644
--- a/pkg/agent/loop.go
+++ b/pkg/agent/loop.go
@@ -40,6 +40,8 @@ type AgentLoop struct {
provider providers.LLMProvider
workspace string
model string
+ maxTokens int
+ temperature float64
maxIterations int
sessions *session.SessionManager
contextBuilder *ContextBuilder
@@ -56,6 +58,8 @@ type AgentLoop struct {
providerNames []string
providerPool map[string]providers.LLMProvider
providerResponses map[string]config.ProviderResponsesConfig
+ providerMaxTokens map[string]int
+ providerTemperatures map[string]float64
telegramStreaming bool
providerMu sync.RWMutex
sessionProvider map[string]string
@@ -222,6 +226,8 @@ func NewAgentLoop(cfg *config.Config, msgBus *bus.MessageBus, provider providers
provider: provider,
workspace: workspace,
model: provider.GetDefaultModel(),
+ maxTokens: cfg.Agents.Defaults.MaxTokens,
+ temperature: cfg.Agents.Defaults.Temperature,
maxIterations: cfg.Agents.Defaults.MaxToolIterations,
sessions: sessionsManager,
contextBuilder: NewContextBuilder(workspace, func() []string { return toolsRegistry.GetSummaries() }),
@@ -237,6 +243,8 @@ func NewAgentLoop(cfg *config.Config, msgBus *bus.MessageBus, provider providers
sessionProvider: map[string]string{},
sessionStreamed: map[string]bool{},
providerResponses: map[string]config.ProviderResponsesConfig{},
+ providerMaxTokens: map[string]int{},
+ providerTemperatures: map[string]float64{},
telegramStreaming: cfg.Channels.Telegram.Streaming,
subagentManager: subagentManager,
subagentRouter: subagentRouter,
@@ -265,6 +273,8 @@ func NewAgentLoop(cfg *config.Config, msgBus *bus.MessageBus, provider providers
loop.providerNames = append(loop.providerNames, primaryName)
if pc, ok := config.ProviderConfigByName(cfg, primaryName); ok {
loop.providerResponses[primaryName] = pc.Responses
+ loop.providerMaxTokens[primaryName] = pc.MaxTokens
+ loop.providerTemperatures[primaryName] = pc.Temperature
}
seenProviders := map[string]struct{}{primaryName: {}}
providerConfigs := config.AllProviderConfigs(cfg)
@@ -304,6 +314,8 @@ func NewAgentLoop(cfg *config.Config, msgBus *bus.MessageBus, provider providers
modelName = strings.TrimSpace(pc.Models[0])
}
loop.providerResponses[providerName] = pc.Responses
+ loop.providerMaxTokens[providerName] = pc.MaxTokens
+ loop.providerTemperatures[providerName] = pc.Temperature
}
seenProviders[providerName] = struct{}{}
loop.providerNames = append(loop.providerNames, providerName)
@@ -464,7 +476,7 @@ func (al *AgentLoop) buildSessionShards(ctx context.Context) []chan bus.InboundM
return shards
}
-func (al *AgentLoop) tryFallbackProviders(ctx context.Context, msg bus.InboundMessage, messages []providers.Message, toolDefs []providers.ToolDefinition, options map[string]interface{}, primaryErr error) (*providers.LLMResponse, string, error) {
+func (al *AgentLoop) tryFallbackProviders(ctx context.Context, msg bus.InboundMessage, messages []providers.Message, toolDefs []providers.ToolDefinition, primaryErr error) (*providers.LLMResponse, string, error) {
if len(al.providerChain) <= 1 {
return nil, "", primaryErr
}
@@ -495,8 +507,10 @@ func (al *AgentLoop) tryFallbackProviders(ctx context.Context, msg bus.InboundMe
lastErr = err
continue
}
- resp, err := p.Chat(ctx, messages, toolDefs, candidateModel, options)
+ fallbackOptions := al.buildResponsesOptionsForProvider(msg.SessionKey, candidate.name, int64(al.maxTokensForProvider(candidate.name)), al.temperatureForProvider(candidate.name))
+ resp, err := p.Chat(ctx, messages, toolDefs, candidateModel, fallbackOptions)
if err == nil {
+ al.setSessionProvider(msg.SessionKey, candidate.name)
logger.WarnCF("agent", logger.C0150, map[string]interface{}{"provider": candidate.name, "model": candidateModel, "ref": candidate.ref})
return resp, candidate.name, nil
}
@@ -550,6 +564,76 @@ func (al *AgentLoop) ensureProviderCandidate(candidate providerCandidate) (provi
return created, model, nil
}
+// providerCandidateByName looks up an entry of the provider chain by name.
+// The comparison ignores case and surrounding whitespace. The boolean result
+// is false when the receiver is nil, the name is blank, or nothing matches.
+func (al *AgentLoop) providerCandidateByName(name string) (providerCandidate, bool) {
+	wanted := strings.TrimSpace(name)
+	if al == nil || wanted == "" {
+		return providerCandidate{}, false
+	}
+	for _, entry := range al.providerChain {
+		if !strings.EqualFold(strings.TrimSpace(entry.name), wanted) {
+			continue
+		}
+		return entry, true
+	}
+	return providerCandidate{}, false
+}
+
+// defaultProviderName returns the trimmed first entry of providerNames, or ""
+// when the receiver is nil or no provider names are registered.
+func (al *AgentLoop) defaultProviderName() string {
+	if al != nil && len(al.providerNames) > 0 {
+		return strings.TrimSpace(al.providerNames[0])
+	}
+	return ""
+}
+
+// sessionProviderName returns the provider name recorded for sessionKey,
+// falling back to defaultProviderName when the recorded value is blank.
+func (al *AgentLoop) sessionProviderName(sessionKey string) string {
+	if pinned := strings.TrimSpace(al.getSessionProvider(sessionKey)); pinned != "" {
+		return pinned
+	}
+	return al.defaultProviderName()
+}
+
+// isKnownProviderName reports whether name matches any entry of providerNames,
+// ignoring case and surrounding whitespace. A blank name is never known.
+func (al *AgentLoop) isKnownProviderName(name string) bool {
+	wanted := strings.TrimSpace(name)
+	if wanted == "" {
+		return false
+	}
+	for i := range al.providerNames {
+		if strings.EqualFold(strings.TrimSpace(al.providerNames[i]), wanted) {
+			return true
+		}
+	}
+	return false
+}
+
+// activeProviderForSession resolves the provider, model and provider name to
+// use for sessionKey.
+//
+//   - No resolvable name: the primary provider and model with an empty name.
+//   - Name equals the default provider: the primary provider; the model falls
+//     back to the provider's default model when al.model is blank.
+//   - Name not present in the provider chain: the primary provider and model,
+//     but the requested name is still returned.
+//   - Otherwise the chain entry is materialised via ensureProviderCandidate,
+//     whose error is returned unchanged.
+func (al *AgentLoop) activeProviderForSession(sessionKey string) (providers.LLMProvider, string, string, error) {
+	if al == nil {
+		return nil, "", "", fmt.Errorf("agent loop is nil")
+	}
+	resolved := al.sessionProviderName(sessionKey)
+	switch {
+	case resolved == "":
+		return al.provider, al.model, "", nil
+	case strings.EqualFold(resolved, al.defaultProviderName()):
+		primaryModel := strings.TrimSpace(al.model)
+		if primaryModel == "" && al.provider != nil {
+			primaryModel = strings.TrimSpace(al.provider.GetDefaultModel())
+		}
+		return al.provider, primaryModel, resolved, nil
+	}
+	entry, found := al.providerCandidateByName(resolved)
+	if !found {
+		return al.provider, al.model, resolved, nil
+	}
+	created, createdModel, err := al.ensureProviderCandidate(entry)
+	if err != nil {
+		return nil, "", resolved, err
+	}
+	return created, createdModel, resolved, nil
+}
+
func automaticFallbackPriority(name string) int {
switch normalizeFallbackProviderName(name) {
case "claude":
@@ -610,10 +694,22 @@ func (al *AgentLoop) getSessionProvider(sessionKey string) string {
}
+// syncSessionDefaultProvider keeps a session's recorded provider when it is a
+// known provider name; otherwise (blank or unknown) it records the default
+// provider, provided one is configured.
func (al *AgentLoop) syncSessionDefaultProvider(sessionKey string) {
- if al == nil || len(al.providerNames) == 0 {
+ if al == nil {
return
}
- al.setSessionProvider(sessionKey, al.providerNames[0])
+	current := strings.TrimSpace(al.getSessionProvider(sessionKey))
+	if current != "" && al.isKnownProviderName(current) {
+		return
+	}
+	if name := al.defaultProviderName(); name != "" {
+		al.setSessionProvider(sessionKey, name)
+	}
}
func (al *AgentLoop) markSessionStreamed(sessionKey string) {
@@ -707,6 +803,8 @@ func sessionShardIndex(sessionKey string, shardCount int) int {
return int(h.Sum32() % uint32(shardCount))
}
+// thinkTagPattern matches <think>...</think> reasoning blocks so they can be
+// stripped from user-visible output. (?s) lets the lazy wildcard span newlines.
+// Without the tag delimiters the pattern only matches the empty string and
+// ReplaceAllString leaves the content untouched.
+var thinkTagPattern = regexp.MustCompile(`(?s)<think>.*?</think>`)
+
func (al *AgentLoop) getTrigger(msg bus.InboundMessage) string {
if msg.Metadata != nil {
if t := strings.TrimSpace(msg.Metadata["trigger"]); t != "" {
@@ -748,6 +846,413 @@ func (al *AgentLoop) shouldSuppressOutbound(msg bus.InboundMessage, response str
return len(r) <= maxChars
}
+// llmTurnLoopConfig bundles the per-turn inputs consumed by runLLMTurnLoop and
+// the helpers it calls (requestLLMResponse, executeResponseToolCalls).
+// NOTE(review): Go convention discourages storing a context.Context in a
+// struct; here it is only carried for the duration of one turn — confirm this
+// is intentional.
+type llmTurnLoopConfig struct {
+ // ctx is passed to provider calls, fallback attempts and tool execution.
+ ctx context.Context
+ // triggerMsg is handed to tryFallbackProviders; its Metadata["message_id"]
+ // is used as the reply target of streamed snapshots.
+ triggerMsg bus.InboundMessage
+ // sessionKey selects the active provider and request options, and is the
+ // key marked as streamed when a snapshot is published.
+ sessionKey string
+ // toolChannel and toolChatID are passed to tool execution and are the
+ // destination of streamed outbound messages.
+ toolChannel string
+ toolChatID string
+ // messages is the initial conversation; the loop works on a copy.
+ messages []providers.Message
+ // media and mediaItems are passed to injectResponsesMediaParts once,
+ // before the first iteration.
+ media []string
+ mediaItems []bus.MediaItem
+ // enableStreaming requests ChatStream when the active provider implements
+ // providers.StreamingLLMProvider; otherwise Chat is used.
+ enableStreaming bool
+ // errorLogCode is the log code used for provider-resolution and LLM-call
+ // failures.
+ errorLogCode logger.CodeID
+ // logDirectResponse enables the info log emitted when a response carries
+ // no tool calls.
+ logDirectResponse bool
+}
+
+// llmTurnLoopResult is what runLLMTurnLoop hands back to its caller.
+type llmTurnLoopResult struct {
+ // messages is the working conversation, including the assistant tool-call
+ // and tool-result messages appended during the loop.
+ messages []providers.Message
+ // pendingPersist holds the assistant tool-call and tool-result messages
+ // produced by the loop; the loop itself does not write them to the session.
+ pendingPersist []providers.Message
+ // finalContent is the content of the first response without tool calls;
+ // it stays empty when the loop ends without such a response.
+ finalContent string
+ // iteration is the number of loop iterations that were started.
+ iteration int
+ // hasToolActivity is true once any response requested tool calls.
+ hasToolActivity bool
+}
+
+// logLLMTurnRequest emits the debug summary of an outgoing LLM request
+// (C0152). On the first iteration only, it also dumps the full messages and
+// tool definitions (C0153). system_prompt_len is the length of the first
+// message's content, or 0 when there are no messages.
+func logLLMTurnRequest(iteration, maxIterations int, providerName, activeModel string, messages []providers.Message, providerToolDefs []providers.ToolDefinition, maxTokens int, temperature float64) {
+	summary := map[string]interface{}{
+		"iteration":         iteration,
+		"max":               maxIterations,
+		"provider":          providerName,
+		"model":             activeModel,
+		"messages_count":    len(messages),
+		"tools_count":       len(providerToolDefs),
+		"max_tokens":        maxTokens,
+		"temperature":       temperature,
+		"system_prompt_len": 0,
+	}
+	if len(messages) > 0 {
+		summary["system_prompt_len"] = len(messages[0].Content)
+	}
+	logger.DebugCF("agent", logger.C0152, summary)
+	if iteration != 1 {
+		return
+	}
+	logger.DebugCF("agent", logger.C0153, map[string]interface{}{
+		"iteration":     iteration,
+		"messages_json": formatMessagesForLog(messages),
+		"tools_json":    formatToolsForLog(providerToolDefs),
+	})
+}
+
+// logLLMDirectResponse records (C0156) that the model answered without
+// requesting tools, together with the byte length of the answer.
+func logLLMDirectResponse(iteration int, finalContent string) {
+	fields := map[string]interface{}{
+		"iteration":     iteration,
+		"content_chars": len(finalContent),
+	}
+	logger.InfoCF("agent", logger.C0156, fields)
+}
+
+// logLLMToolCalls records (C0157) the names and number of tools requested by
+// the model in the given iteration.
+func logLLMToolCalls(iteration int, toolCalls []providers.ToolCall) {
+	requested := make([]string, len(toolCalls))
+	for i := range toolCalls {
+		requested[i] = toolCalls[i].Name
+	}
+	logger.InfoCF("agent", logger.C0157, map[string]interface{}{
+		"tools":     requested,
+		"count":     len(requested),
+		"iteration": iteration,
+	})
+}
+
+// buildAssistantToolCallMessage converts an LLM response into the assistant
+// message that is appended to the conversation: the response content plus one
+// "function" tool call per requested tool, with arguments JSON-encoded.
+// A nil response yields an assistant message with no content and no calls.
+func buildAssistantToolCallMessage(response *providers.LLMResponse) providers.Message {
+	assistantMsg := providers.Message{Role: "assistant"}
+	// The nil guard must run before response.Content is read; reading it in
+	// the struct literal first would panic on a nil response.
+	if response == nil {
+		return assistantMsg
+	}
+	assistantMsg.Content = response.Content
+	for _, tc := range response.ToolCalls {
+		// The marshal error is deliberately ignored: on failure the arguments
+		// string is simply left empty.
+		argumentsJSON, _ := json.Marshal(tc.Arguments)
+		assistantMsg.ToolCalls = append(assistantMsg.ToolCalls, providers.ToolCall{
+			ID:   tc.ID,
+			Type: "function",
+			Function: &providers.FunctionCall{
+				Name:      tc.Name,
+				Arguments: string(argumentsJSON),
+			},
+		})
+	}
+	return assistantMsg
+}
+
+// executeResponseToolCalls runs every tool call of response in order and
+// returns one "tool" message per call, linked through ToolCallID. A failing
+// tool does not stop the batch: its message content becomes "Error: <err>".
+// Returns nil when response is nil or carries no tool calls.
+func (al *AgentLoop) executeResponseToolCalls(cfg llmTurnLoopConfig, iteration int, response *providers.LLMResponse) []providers.Message {
+	if response == nil {
+		return nil
+	}
+	calls := response.ToolCalls
+	if len(calls) == 0 {
+		return nil
+	}
+	toolMessages := make([]providers.Message, 0, len(calls))
+	for _, call := range calls {
+		// Arguments are marshalled only for the (truncated) log preview.
+		encodedArgs, _ := json.Marshal(call.Arguments)
+		logger.InfoCF("agent", logger.C0172, map[string]interface{}{
+			"tool":      call.Name,
+			"args":      truncate(string(encodedArgs), 200),
+			"iteration": iteration,
+		})
+		contextualArgs := withToolContextArgs(call.Name, call.Arguments, cfg.toolChannel, cfg.toolChatID)
+		output, execErr := al.executeToolCall(cfg.ctx, call.Name, contextualArgs, cfg.toolChannel, cfg.toolChatID)
+		if execErr != nil {
+			output = fmt.Sprintf("Error: %v", execErr)
+		}
+		toolMessages = append(toolMessages, providers.Message{
+			Role:       "tool",
+			Content:    output,
+			ToolCallID: call.ID,
+		})
+	}
+	return toolMessages
+}
+
+// requestLLMResponse performs one LLM call. When cfg.enableStreaming is set
+// and the provider implements providers.StreamingLLMProvider, ChatStream is
+// used and accumulated snapshots are published as "stream" outbound messages;
+// otherwise it falls back to a plain Chat call.
+//
+// Every non-empty delta is accumulated, including whitespace-only ones, so
+// spaces and newlines that arrive as their own chunk are kept in the snapshot.
+// A whitespace-only delta never triggers a push by itself. Pushes are also
+// limited to one per 450ms and to snapshots accepted by
+// shouldFlushTelegramStreamSnapshot.
+func (al *AgentLoop) requestLLMResponse(cfg llmTurnLoopConfig, activeProvider providers.LLMProvider, activeModel string, messages []providers.Message, providerToolDefs []providers.ToolDefinition, options map[string]interface{}) (*providers.LLMResponse, error) {
+	if !cfg.enableStreaming {
+		return activeProvider.Chat(cfg.ctx, messages, providerToolDefs, activeModel, options)
+	}
+	sp, ok := activeProvider.(providers.StreamingLLMProvider)
+	if !ok {
+		return activeProvider.Chat(cfg.ctx, messages, providerToolDefs, activeModel, options)
+	}
+
+	var streamText strings.Builder
+	// Start one second in the past so the first eligible snapshot is pushed
+	// without waiting for the throttle window.
+	lastPush := time.Now().Add(-time.Second)
+	return sp.ChatStream(cfg.ctx, messages, providerToolDefs, activeModel, options, func(delta string) {
+		if delta == "" {
+			return
+		}
+		streamText.WriteString(delta)
+		if strings.TrimSpace(delta) == "" {
+			return
+		}
+		if time.Since(lastPush) < 450*time.Millisecond {
+			return
+		}
+		snapshot := streamText.String()
+		if !shouldFlushTelegramStreamSnapshot(snapshot) {
+			return
+		}
+		lastPush = time.Now()
+		replyID := ""
+		if cfg.triggerMsg.Metadata != nil {
+			replyID = cfg.triggerMsg.Metadata["message_id"]
+		}
+		al.bus.PublishOutbound(bus.OutboundMessage{
+			Channel:   cfg.toolChannel,
+			ChatID:    cfg.toolChatID,
+			Content:   snapshot,
+			Action:    "stream",
+			ReplyToID: replyID,
+		})
+		al.markSessionStreamed(cfg.sessionKey)
+	})
+}
+
+// runLLMTurnLoop drives the request / tool-call cycle for one inbound turn.
+//
+// Each iteration resolves the session's active provider, issues the request
+// (streaming when configured) and, on failure, tries the fallback providers.
+// A response without tool calls ends the loop and becomes finalContent.
+// Otherwise the assistant tool-call message and the tool results are appended
+// to both result.messages and result.pendingPersist, and the iteration budget
+// is extended so that chained tool use gets al.maxIterations further rounds.
+//
+// The partially filled result is returned alongside any error. If the budget
+// runs out, the result is returned with an empty finalContent and a nil error.
+func (al *AgentLoop) runLLMTurnLoop(cfg llmTurnLoopConfig) (llmTurnLoopResult, error) {
+	result := llmTurnLoopResult{
+		messages:       append([]providers.Message(nil), cfg.messages...),
+		pendingPersist: make([]providers.Message, 0, 16),
+	}
+	maxAllowed := al.maxIterations
+	if maxAllowed < 1 {
+		maxAllowed = 1
+	}
+	toolDefs := al.filteredToolDefinitionsForContext(cfg.ctx)
+	providerToolDefs := al.buildProviderToolDefs(toolDefs)
+	result.messages = injectResponsesMediaParts(result.messages, cfg.media, cfg.mediaItems)
+
+	for result.iteration < maxAllowed {
+		result.iteration++
+		activeProvider, activeModel, providerName, err := al.activeProviderForSession(cfg.sessionKey)
+		if err != nil {
+			logger.ErrorCF("agent", cfg.errorLogCode, map[string]interface{}{
+				"iteration": result.iteration,
+				"error":     err.Error(),
+			})
+			return result, fmt.Errorf("resolve active provider: %w", err)
+		}
+		if activeProvider == nil {
+			return result, fmt.Errorf("active provider unavailable for session %s", strings.TrimSpace(cfg.sessionKey))
+		}
+
+		maxTokens := al.maxTokensForProvider(providerName)
+		temperature := al.temperatureForProvider(providerName)
+		logLLMTurnRequest(result.iteration, al.maxIterations, providerName, activeModel, result.messages, providerToolDefs, maxTokens, temperature)
+
+		options := al.buildResponsesOptions(cfg.sessionKey, int64(maxTokens), temperature)
+		response, err := al.requestLLMResponse(cfg, activeProvider, activeModel, result.messages, providerToolDefs, options)
+
+		if err != nil {
+			fb, _, ferr := al.tryFallbackProviders(cfg.ctx, cfg.triggerMsg, result.messages, providerToolDefs, err)
+			switch {
+			case ferr != nil:
+				err = ferr
+			case fb != nil:
+				response, err = fb, nil
+			}
+			// If the fallback produced neither a response nor an error, the
+			// primary error is kept rather than being overwritten with nil.
+		}
+		// A provider may return (nil, nil); treat it as a failure instead of
+		// dereferencing a nil response below.
+		if err == nil && response == nil {
+			err = fmt.Errorf("provider returned no response")
+		}
+		if err != nil {
+			logger.ErrorCF("agent", cfg.errorLogCode, map[string]interface{}{
+				"iteration": result.iteration,
+				"error":     err.Error(),
+			})
+			return result, fmt.Errorf("LLM call failed: %w", err)
+		}
+
+		if len(response.ToolCalls) == 0 {
+			result.finalContent = response.Content
+			if cfg.logDirectResponse {
+				logLLMDirectResponse(result.iteration, result.finalContent)
+			}
+			return result, nil
+		}
+
+		logLLMToolCalls(result.iteration, response.ToolCalls)
+
+		assistantMsg := buildAssistantToolCallMessage(response)
+		result.messages = append(result.messages, assistantMsg)
+		result.pendingPersist = append(result.pendingPersist, assistantMsg)
+		result.hasToolActivity = true
+		// Rolling window: keep extending the budget while tools keep chaining.
+		if maxAllowed < result.iteration+al.maxIterations {
+			maxAllowed = result.iteration + al.maxIterations
+		}
+
+		for _, toolResultMsg := range al.executeResponseToolCalls(cfg, result.iteration, response) {
+			result.messages = append(result.messages, toolResultMsg)
+			result.pendingPersist = append(result.pendingPersist, toolResultMsg)
+		}
+	}
+
+	return result, nil
+}
+
+// logInboundMessageStart logs (C0171) the routing identifiers of an inbound
+// message together with an 80-character preview of its content.
+func (al *AgentLoop) logInboundMessageStart(msg bus.InboundMessage) {
+	fields := map[string]interface{}{
+		"channel":     msg.Channel,
+		"chat_id":     msg.ChatID,
+		"sender_id":   msg.SenderID,
+		"session_key": msg.SessionKey,
+		"preview":     truncate(msg.Content, 80),
+	}
+	logger.InfoCF("agent", logger.C0171, fields)
+}
+
+// prepareUserMessageContext records the user's message in the session and
+// builds the provider message list for this turn. It returns that list and
+// the response language chosen by DetectResponseLanguage.
+func (al *AgentLoop) prepareUserMessageContext(msg bus.InboundMessage, memoryNamespace string) ([]providers.Message, string) {
+ // History and summary are read before the new user message is added to
+ // the session; msg.Content is passed to the builder separately below.
+ history := al.sessions.GetHistory(msg.SessionKey)
+ summary := al.sessions.GetSummary(msg.SessionKey)
+ // NOTE(review): the user message is stored before the LLM call, so it is
+ // in the session even if the turn later fails — confirm this is intended.
+ al.sessions.AddMessage(msg.SessionKey, "user", msg.Content)
+ // A language preference extracted from the message is stored first, so
+ // the GetLanguagePreferences call below runs after the update.
+ if explicitPref := ExtractLanguagePreference(msg.Content); explicitPref != "" {
+ al.sessions.SetPreferredLanguage(msg.SessionKey, explicitPref)
+ }
+ preferredLang, lastLang := al.sessions.GetLanguagePreferences(msg.SessionKey)
+ responseLang := DetectResponseLanguage(msg.Content, preferredLang, lastLang)
+ messages := al.contextBuilder.BuildMessagesWithMemoryNamespace(
+ history,
+ summary,
+ msg.Content,
+ nil,
+ msg.Channel,
+ msg.ChatID,
+ responseLang,
+ memoryNamespace,
+ )
+ return messages, responseLang
+}
+
+// finalizeUserMessage persists the outcome of a turn: first the buffered
+// tool-call/tool-result messages in order, then the final assistant reply.
+// It then records the response language, compacts the session if needed and
+// saves it. The result of Save is discarded.
+func (al *AgentLoop) finalizeUserMessage(sessionKey, responseLang string, pendingPersist []providers.Message, finalContent string) {
+	for i := range pendingPersist {
+		al.sessions.AddMessageFull(sessionKey, pendingPersist[i])
+	}
+	reply := providers.Message{Role: "assistant", Content: finalContent}
+	al.sessions.AddMessageFull(sessionKey, reply)
+	al.sessions.SetLastLanguage(sessionKey, responseLang)
+	al.compactSessionIfNeeded(sessionKey)
+	current := al.sessions.GetOrCreate(sessionKey)
+	_ = al.sessions.Save(current)
+}
+
+// prepareSystemMessageContext builds the provider message list for a system
+// message handled in the context of sessionKey, targeting the origin channel
+// and chat. It returns that list and the detected response language. Unlike
+// the user-message variant it does not write anything to the session.
+func (al *AgentLoop) prepareSystemMessageContext(sessionKey string, msg bus.InboundMessage, originChannel, originChatID string) ([]providers.Message, string) {
+	var (
+		pastMessages = al.sessions.GetHistory(sessionKey)
+		pastSummary  = al.sessions.GetSummary(sessionKey)
+	)
+	preferred, last := al.sessions.GetLanguagePreferences(sessionKey)
+	lang := DetectResponseLanguage(msg.Content, preferred, last)
+	built := al.contextBuilder.BuildMessages(pastMessages, pastSummary, msg.Content, nil, originChannel, originChatID, lang)
+	return built, lang
+}
+
+// finalizeSystemMessage persists a system-triggered turn. The triggering
+// content is stored as a user message prefixed with "[System: <sender>]".
+// The remaining steps (buffered tool messages, final assistant reply, last
+// language, compaction, save) are the same as for a user turn and are
+// delegated to finalizeUserMessage.
+func (al *AgentLoop) finalizeSystemMessage(sessionKey, responseLang string, msg bus.InboundMessage, pendingPersist []providers.Message, finalContent string) {
+	announcement := fmt.Sprintf("[System: %s] %s", msg.SenderID, msg.Content)
+	al.sessions.AddMessage(sessionKey, "user", announcement)
+	al.finalizeUserMessage(sessionKey, responseLang, pendingPersist, finalContent)
+}
+
+// startSpecTaskForMessage ensures the spec-coding docs exist and then tries to
+// start a spec-coding task for the message content. Both failures are only
+// logged as warnings (C0172). A failure to start the task yields the zero
+// specCodingTaskRef; otherwise the normalized task reference is returned.
+func (al *AgentLoop) startSpecTaskForMessage(msg bus.InboundMessage) specCodingTaskRef {
+	warn := func(cause error) {
+		logger.WarnCF("agent", logger.C0172, map[string]interface{}{
+			"session_key": msg.SessionKey,
+			"error":       cause.Error(),
+		})
+	}
+	if err := al.maybeEnsureSpecCodingDocs(msg.Content); err != nil {
+		warn(err)
+	}
+	started, err := al.maybeStartSpecCodingTask(msg.Content)
+	if err != nil {
+		warn(err)
+		return specCodingTaskRef{}
+	}
+	return normalizeSpecCodingTaskRef(started)
+}
+
+// reopenSpecTaskOnError reopens the spec-coding task after a failed turn.
+// It does nothing when err is nil or the task reference has no summary.
+// A failure to reopen is logged as a warning (C0172), not returned.
+func (al *AgentLoop) reopenSpecTaskOnError(specTaskRef specCodingTaskRef, msg bus.InboundMessage, err error) {
+	if err == nil || specTaskRef.Summary == "" {
+		return
+	}
+	rerr := al.maybeReopenSpecCodingTask(specTaskRef, msg.Content, err.Error())
+	if rerr == nil {
+		return
+	}
+	logger.WarnCF("agent", logger.C0172, map[string]interface{}{
+		"session_key": msg.SessionKey,
+		"error":       rerr.Error(),
+	})
+}
+
+// completeSpecTaskOnSuccess marks the spec-coding task as completed with the
+// given output. It does nothing when the task reference has no summary.
+// A failure to complete is logged as a warning (C0172), not returned.
+func (al *AgentLoop) completeSpecTaskOnSuccess(specTaskRef specCodingTaskRef, msg bus.InboundMessage, output string) {
+	if specTaskRef.Summary == "" {
+		return
+	}
+	cerr := al.maybeCompleteSpecCodingTask(specTaskRef, output)
+	if cerr == nil {
+		return
+	}
+	logger.WarnCF("agent", logger.C0172, map[string]interface{}{
+		"session_key": msg.SessionKey,
+		"error":       cerr.Error(),
+	})
+}
+
+// recoverFinalContentAfterToolCalls makes one extra, tool-less request to get
+// a textual answer after a turn that ran tools but produced no final content.
+// The request uses the session's active provider with that provider's token
+// limit and a fixed temperature of 0.2. It returns "" when there was no tool
+// activity, the provider cannot be resolved, or the call fails or yields
+// nothing. Errors are logged as warnings (C0172).
+func (al *AgentLoop) recoverFinalContentAfterToolCalls(ctx context.Context, sessionKey string, messages []providers.Message, hasToolActivity bool) string {
+	if !hasToolActivity {
+		return ""
+	}
+	warn := func(cause error) {
+		logger.WarnCF("agent", logger.C0172, map[string]interface{}{
+			"session_key": sessionKey,
+			"error":       cause.Error(),
+		})
+	}
+	activeProvider, activeModel, providerName, err := al.activeProviderForSession(sessionKey)
+	if err != nil {
+		warn(err)
+		return ""
+	}
+	if activeProvider == nil {
+		return ""
+	}
+	tokenLimit := int64(al.maxTokensForProvider(providerName))
+	options := al.buildResponsesOptionsForProvider(sessionKey, providerName, tokenLimit, 0.2)
+	forced, ferr := activeProvider.Chat(ctx, messages, nil, activeModel, options)
+	switch {
+	case ferr != nil:
+		warn(ferr)
+		return ""
+	case forced == nil:
+		return ""
+	}
+	return forced.Content
+}
+
+// sanitizeUserVisibleContent removes every thinkTagPattern match from
+// finalContent. If that leaves nothing although the input was non-empty, and
+// the turn finished in its first iteration, a fixed placeholder is returned
+// instead of an empty reply.
+func sanitizeUserVisibleContent(finalContent string, iteration int) string {
+	visible := thinkTagPattern.ReplaceAllString(finalContent, "")
+	strippedToNothing := visible == "" && finalContent != ""
+	if strippedToNothing && iteration == 1 {
+		return "Thinking process completed."
+	}
+	return visible
+}
+
+// finalizeUserTurnResponse turns a loop result into the reply for the user.
+// An empty final content is first replaced by the recovery request's answer
+// (which may itself be empty). The content is then sanitized for display and
+// the turn is persisted with the sanitized text. It returns the raw content
+// followed by the user-visible content.
+func (al *AgentLoop) finalizeUserTurnResponse(ctx context.Context, msg bus.InboundMessage, responseLang string, loopResult llmTurnLoopResult) (string, string) {
+	raw := loopResult.finalContent
+	if raw == "" {
+		raw = al.recoverFinalContentAfterToolCalls(ctx, msg.SessionKey, loopResult.messages, loopResult.hasToolActivity)
+	}
+	visible := sanitizeUserVisibleContent(raw, loopResult.iteration)
+	al.finalizeUserMessage(msg.SessionKey, responseLang, loopResult.pendingPersist, visible)
+	return raw, visible
+}
+
+// maybeHandleAutoRoute gives the auto-router a chance to handle the message.
+// The final boolean reports whether routing took over. When it did, the
+// spec-coding task is reopened on a routing error or completed with the
+// routed output on success, and the routed output and error are returned.
+// NOTE(review): Go convention puts the error last in the result list; the
+// order is kept because the caller depends on it.
+func (al *AgentLoop) maybeHandleAutoRoute(ctx context.Context, msg bus.InboundMessage, specTaskRef specCodingTaskRef) (string, error, bool) {
+	routed, ok, routeErr := al.maybeAutoRoute(ctx, msg)
+	switch {
+	case !ok:
+		return "", nil, false
+	case routeErr != nil:
+		al.reopenSpecTaskOnError(specTaskRef, msg, routeErr)
+		return routed, routeErr, true
+	default:
+		al.completeSpecTaskOnSuccess(specTaskRef, msg, routed)
+		return routed, nil, true
+	}
+}
+
func loadHeartbeatAckToken(workspace string) string {
workspace = strings.TrimSpace(workspace)
if workspace == "" {
@@ -879,282 +1384,37 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage)
}
defer release()
al.syncSessionDefaultProvider(msg.SessionKey)
- // Add message preview to log
- preview := truncate(msg.Content, 80)
- logger.InfoCF("agent", logger.C0171,
- map[string]interface{}{
- "channel": msg.Channel,
- "chat_id": msg.ChatID,
- "sender_id": msg.SenderID,
- "session_key": msg.SessionKey,
- "preview": preview,
- })
+ al.logInboundMessageStart(msg)
// Route system messages to processSystemMessage
if msg.Channel == "system" {
return al.processSystemMessage(ctx, msg)
}
- specTaskRef := specCodingTaskRef{}
- if err := al.maybeEnsureSpecCodingDocs(msg.Content); err != nil {
- logger.WarnCF("agent", logger.C0172, map[string]interface{}{
- "session_key": msg.SessionKey,
- "error": err.Error(),
- })
- }
- if taskRef, err := al.maybeStartSpecCodingTask(msg.Content); err != nil {
- logger.WarnCF("agent", logger.C0172, map[string]interface{}{
- "session_key": msg.SessionKey,
- "error": err.Error(),
- })
- } else {
- specTaskRef = normalizeSpecCodingTaskRef(taskRef)
- }
- if routed, ok, routeErr := al.maybeAutoRoute(ctx, msg); ok {
- if routeErr != nil && specTaskRef.Summary != "" {
- if err := al.maybeReopenSpecCodingTask(specTaskRef, msg.Content, routeErr.Error()); err != nil {
- logger.WarnCF("agent", logger.C0172, map[string]interface{}{
- "session_key": msg.SessionKey,
- "error": err.Error(),
- })
- }
- }
- if routeErr == nil && specTaskRef.Summary != "" {
- if err := al.maybeCompleteSpecCodingTask(specTaskRef, routed); err != nil {
- logger.WarnCF("agent", logger.C0172, map[string]interface{}{
- "session_key": msg.SessionKey,
- "error": err.Error(),
- })
- }
- }
+ specTaskRef := al.startSpecTaskForMessage(msg)
+ if routed, routeErr, handled := al.maybeHandleAutoRoute(ctx, msg, specTaskRef); handled {
return routed, routeErr
}
- history := al.sessions.GetHistory(msg.SessionKey)
- summary := al.sessions.GetSummary(msg.SessionKey)
- if explicitPref := ExtractLanguagePreference(msg.Content); explicitPref != "" {
- al.sessions.SetPreferredLanguage(msg.SessionKey, explicitPref)
- }
- preferredLang, lastLang := al.sessions.GetLanguagePreferences(msg.SessionKey)
- responseLang := DetectResponseLanguage(msg.Content, preferredLang, lastLang)
+ messages, responseLang := al.prepareUserMessageContext(msg, memoryNamespace)
- messages := al.contextBuilder.BuildMessagesWithMemoryNamespace(
- history,
- summary,
- msg.Content,
- nil,
- msg.Channel,
- msg.ChatID,
- responseLang,
- memoryNamespace,
- )
-
- iteration := 0
- var finalContent string
- hasToolActivity := false
- lastToolOutputs := make([]string, 0, 4)
- maxAllowed := al.maxIterations
- if maxAllowed < 1 {
- maxAllowed = 1
- }
- for iteration < maxAllowed {
- iteration++
-
- logger.DebugCF("agent", logger.C0151,
- map[string]interface{}{
- "iteration": iteration,
- "max": al.maxIterations,
- })
-
- toolDefs := al.filteredToolDefinitionsForContext(ctx)
- providerToolDefs := al.buildProviderToolDefs(toolDefs)
-
- // Log LLM request details
- logger.DebugCF("agent", logger.C0152,
- map[string]interface{}{
- "iteration": iteration,
- "model": al.model,
- "messages_count": len(messages),
- "tools_count": len(providerToolDefs),
- "max_tokens": 8192,
- "temperature": 0.7,
- "system_prompt_len": len(messages[0].Content),
- })
-
- // Log full messages (detailed)
- logger.DebugCF("agent", logger.C0153,
- map[string]interface{}{
- "iteration": iteration,
- "messages_json": formatMessagesForLog(messages),
- "tools_json": formatToolsForLog(providerToolDefs),
- })
-
- messages = injectResponsesMediaParts(messages, msg.Media, msg.MediaItems)
- options := al.buildResponsesOptions(msg.SessionKey, 8192, 0.7)
- var response *providers.LLMResponse
- var err error
- if msg.Channel == "telegram" && al.telegramStreaming {
- if sp, ok := al.provider.(providers.StreamingLLMProvider); ok {
- streamText := ""
- lastPush := time.Now().Add(-time.Second)
- response, err = sp.ChatStream(ctx, messages, providerToolDefs, al.model, options, func(delta string) {
- if strings.TrimSpace(delta) == "" {
- return
- }
- streamText += delta
- if time.Since(lastPush) < 450*time.Millisecond {
- return
- }
- if !shouldFlushTelegramStreamSnapshot(streamText) {
- return
- }
- lastPush = time.Now()
- replyID := ""
- if msg.Metadata != nil {
- replyID = msg.Metadata["message_id"]
- }
- // Stream with formatted rendering once snapshot is syntactically safe.
- al.bus.PublishOutbound(bus.OutboundMessage{Channel: msg.Channel, ChatID: msg.ChatID, Content: streamText, Action: "stream", ReplyToID: replyID})
- al.markSessionStreamed(msg.SessionKey)
- })
- } else {
- response, err = al.provider.Chat(ctx, messages, providerToolDefs, al.model, options)
- }
- } else {
- response, err = al.provider.Chat(ctx, messages, providerToolDefs, al.model, options)
- }
-
- if err != nil {
- if fb, _, ferr := al.tryFallbackProviders(ctx, msg, messages, providerToolDefs, options, err); ferr == nil && fb != nil {
- response = fb
- err = nil
- } else {
- err = ferr
- }
- }
- if err != nil {
- logger.ErrorCF("agent", logger.C0155,
- map[string]interface{}{
- "iteration": iteration,
- "error": err.Error(),
- })
- if specTaskRef.Summary != "" {
- if rerr := al.maybeReopenSpecCodingTask(specTaskRef, msg.Content, err.Error()); rerr != nil {
- logger.WarnCF("agent", logger.C0172, map[string]interface{}{
- "session_key": msg.SessionKey,
- "error": rerr.Error(),
- })
- }
- }
- return "", fmt.Errorf("LLM call failed: %w", err)
- }
-
- if len(response.ToolCalls) == 0 {
- finalContent = response.Content
- logger.InfoCF("agent", logger.C0156,
- map[string]interface{}{
- "iteration": iteration,
- "content_chars": len(finalContent),
- })
- break
- }
-
- toolNames := make([]string, 0, len(response.ToolCalls))
- for _, tc := range response.ToolCalls {
- toolNames = append(toolNames, tc.Name)
- }
- logger.InfoCF("agent", logger.C0157,
- map[string]interface{}{
- "tools": toolNames,
- "count": len(toolNames),
- "iteration": iteration,
- })
-
- assistantMsg := providers.Message{
- Role: "assistant",
- Content: response.Content,
- }
-
- for _, tc := range response.ToolCalls {
- argumentsJSON, _ := json.Marshal(tc.Arguments)
- assistantMsg.ToolCalls = append(assistantMsg.ToolCalls, providers.ToolCall{
- ID: tc.ID,
- Type: "function",
- Function: &providers.FunctionCall{
- Name: tc.Name,
- Arguments: string(argumentsJSON),
- },
- })
- }
- messages = append(messages, assistantMsg)
- // Persist assistant message with tool calls.
- al.sessions.AddMessageFull(msg.SessionKey, assistantMsg)
-
- hasToolActivity = true
- // Extend rolling window as long as tools keep chaining.
- if maxAllowed < iteration+al.maxIterations {
- maxAllowed = iteration + al.maxIterations
- }
- for _, tc := range response.ToolCalls {
- // Log tool call with arguments preview
- argsJSON, _ := json.Marshal(tc.Arguments)
- argsPreview := truncate(string(argsJSON), 200)
- logger.InfoCF("agent", logger.C0172,
- map[string]interface{}{
- "tool": tc.Name,
- "args": argsPreview,
- "iteration": iteration,
- })
-
- execArgs := withToolContextArgs(tc.Name, tc.Arguments, msg.Channel, msg.ChatID)
- result, err := al.executeToolCall(ctx, tc.Name, execArgs, msg.Channel, msg.ChatID)
- if err != nil {
- result = fmt.Sprintf("Error: %v", err)
- }
- if len(lastToolOutputs) < 4 {
- lastToolOutputs = append(lastToolOutputs, fmt.Sprintf("%s: %s", tc.Name, truncate(strings.ReplaceAll(result, "\n", " "), 180)))
- }
- toolResultMsg := providers.Message{
- Role: "tool",
- Content: result,
- ToolCallID: tc.ID,
- }
- messages = append(messages, toolResultMsg)
- // Persist tool result message.
- al.sessions.AddMessageFull(msg.SessionKey, toolResultMsg)
- }
- }
-
- if finalContent == "" && hasToolActivity {
- forced, ferr := al.provider.Chat(ctx, messages, nil, al.model, map[string]interface{}{"max_tokens": 8192, "temperature": 0.2})
- if ferr == nil && forced != nil && forced.Content != "" {
- finalContent = forced.Content
- }
- }
-
- // Filter out ... content from user-facing response
- // Keep full content in debug logs if needed, but remove from final output
- re := regexp.MustCompile(`(?s).*?`)
- userContent := re.ReplaceAllString(finalContent, "")
- if userContent == "" && finalContent != "" {
- // If only thoughts were present, maybe provide a generic "Done" or keep something?
- // For now, let's assume thoughts are auxiliary and empty response is okay if tools did work.
- // If no tools ran and only thoughts, user might be confused.
- if iteration == 1 {
- userContent = "Thinking process completed."
- }
- }
-
- al.sessions.AddMessage(msg.SessionKey, "user", msg.Content)
-
- // Persist full assistant response (including reasoning/tool flow outcomes when present).
- al.sessions.AddMessageFull(msg.SessionKey, providers.Message{
- Role: "assistant",
- Content: userContent,
+ loopResult, err := al.runLLMTurnLoop(llmTurnLoopConfig{
+ ctx: ctx,
+ triggerMsg: msg,
+ sessionKey: msg.SessionKey,
+ toolChannel: msg.Channel,
+ toolChatID: msg.ChatID,
+ messages: messages,
+ media: msg.Media,
+ mediaItems: msg.MediaItems,
+ enableStreaming: msg.Channel == "telegram" && al.telegramStreaming,
+ errorLogCode: logger.C0155,
+ logDirectResponse: true,
})
- al.sessions.SetLastLanguage(msg.SessionKey, responseLang)
- al.compactSessionIfNeeded(msg.SessionKey)
-
- al.sessions.Save(al.sessions.GetOrCreate(msg.SessionKey))
+ if err != nil {
+ al.reopenSpecTaskOnError(specTaskRef, msg, err)
+ return "", err
+ }
+ finalContent, userContent := al.finalizeUserTurnResponse(ctx, msg, responseLang, loopResult)
// Log response preview (original content)
responsePreview := truncate(finalContent, 120)
@@ -1163,19 +1423,12 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage)
"channel": msg.Channel,
"sender_id": msg.SenderID,
"preview": responsePreview,
- "iterations": iteration,
+ "iterations": loopResult.iteration,
"final_length": len(finalContent),
"user_length": len(userContent),
})
- if specTaskRef.Summary != "" {
- if err := al.maybeCompleteSpecCodingTask(specTaskRef, userContent); err != nil {
- logger.WarnCF("agent", logger.C0172, map[string]interface{}{
- "session_key": msg.SessionKey,
- "error": err.Error(),
- })
- }
- }
+ al.completeSpecTaskOnSuccess(specTaskRef, msg, userContent)
al.appendDailySummaryLog(msg, userContent)
return userContent, nil
}
@@ -1327,130 +1580,30 @@ func (al *AgentLoop) processSystemMessage(ctx context.Context, msg bus.InboundMe
// Use the origin session for context
sessionKey := fmt.Sprintf("%s:%s", originChannel, originChatID)
- // Build messages with the announce content
- history := al.sessions.GetHistory(sessionKey)
- summary := al.sessions.GetSummary(sessionKey)
- preferredLang, lastLang := al.sessions.GetLanguagePreferences(sessionKey)
- responseLang := DetectResponseLanguage(msg.Content, preferredLang, lastLang)
- messages := al.contextBuilder.BuildMessages(
- history,
- summary,
- msg.Content,
- nil,
- originChannel,
- originChatID,
- responseLang,
- )
+ messages, responseLang := al.prepareSystemMessageContext(sessionKey, msg, originChannel, originChatID)
- iteration := 0
- var finalContent string
-
- for iteration < al.maxIterations {
- iteration++
-
- toolDefs := al.filteredToolDefinitionsForContext(ctx)
- providerToolDefs := al.buildProviderToolDefs(toolDefs)
-
- // Log LLM request details
- logger.DebugCF("agent", logger.C0152,
- map[string]interface{}{
- "iteration": iteration,
- "model": al.model,
- "messages_count": len(messages),
- "tools_count": len(providerToolDefs),
- "max_tokens": 8192,
- "temperature": 0.7,
- "system_prompt_len": len(messages[0].Content),
- })
-
- // Log full messages (detailed)
- logger.DebugCF("agent", logger.C0153,
- map[string]interface{}{
- "iteration": iteration,
- "messages_json": formatMessagesForLog(messages),
- "tools_json": formatToolsForLog(providerToolDefs),
- })
-
- options := al.buildResponsesOptions(sessionKey, 8192, 0.7)
- response, err := al.provider.Chat(ctx, messages, providerToolDefs, al.model, options)
-
- if err != nil {
- if fb, _, ferr := al.tryFallbackProviders(ctx, msg, messages, providerToolDefs, options, err); ferr == nil && fb != nil {
- response = fb
- err = nil
- } else {
- err = ferr
- }
- }
- if err != nil {
- logger.ErrorCF("agent", logger.C0162,
- map[string]interface{}{
- "iteration": iteration,
- "error": err.Error(),
- })
- return "", fmt.Errorf("LLM call failed: %w", err)
- }
-
- if len(response.ToolCalls) == 0 {
- finalContent = response.Content
- break
- }
-
- assistantMsg := providers.Message{
- Role: "assistant",
- Content: response.Content,
- }
-
- for _, tc := range response.ToolCalls {
- argumentsJSON, _ := json.Marshal(tc.Arguments)
- assistantMsg.ToolCalls = append(assistantMsg.ToolCalls, providers.ToolCall{
- ID: tc.ID,
- Type: "function",
- Function: &providers.FunctionCall{
- Name: tc.Name,
- Arguments: string(argumentsJSON),
- },
- })
- }
- messages = append(messages, assistantMsg)
- // Persist assistant message with tool calls.
- al.sessions.AddMessageFull(sessionKey, assistantMsg)
-
- for _, tc := range response.ToolCalls {
- execArgs := withToolContextArgs(tc.Name, tc.Arguments, originChannel, originChatID)
- result, err := al.executeToolCall(ctx, tc.Name, execArgs, originChannel, originChatID)
- if err != nil {
- result = fmt.Sprintf("Error: %v", err)
- }
-
- toolResultMsg := providers.Message{
- Role: "tool",
- Content: result,
- ToolCallID: tc.ID,
- }
- messages = append(messages, toolResultMsg)
- // Persist tool result message.
- al.sessions.AddMessageFull(sessionKey, toolResultMsg)
- }
+ loopResult, err := al.runLLMTurnLoop(llmTurnLoopConfig{
+ ctx: ctx,
+ triggerMsg: msg,
+ sessionKey: sessionKey,
+ toolChannel: originChannel,
+ toolChatID: originChatID,
+ messages: messages,
+ errorLogCode: logger.C0162,
+ logDirectResponse: false,
+ })
+ if err != nil {
+ return "", err
}
+ iteration := loopResult.iteration
+ finalContent := loopResult.finalContent
+ pendingPersist := loopResult.pendingPersist
if finalContent == "" {
finalContent = "Background task completed."
}
- // Save to session with system message marker
- al.sessions.AddMessage(sessionKey, "user", fmt.Sprintf("[System: %s] %s", msg.SenderID, msg.Content))
-
- // If finalContent has no tool calls (last LLM turn is direct text),
- // earlier steps were already persisted in-loop; this stores the final reply.
- al.sessions.AddMessageFull(sessionKey, providers.Message{
- Role: "assistant",
- Content: finalContent,
- })
- al.sessions.SetLastLanguage(sessionKey, responseLang)
- al.compactSessionIfNeeded(sessionKey)
-
- al.sessions.Save(al.sessions.GetOrCreate(sessionKey))
+ al.finalizeSystemMessage(sessionKey, responseLang, msg, pendingPersist, finalContent)
logger.InfoCF("agent", logger.C0163,
map[string]interface{}{
@@ -1564,16 +1717,30 @@ func filterToolDefinitionsByContext(ctx context.Context, toolDefs []map[string]i
}
func (al *AgentLoop) buildResponsesOptions(sessionKey string, maxTokens int64, temperature float64) map[string]interface{} {
+ providerName := strings.TrimSpace(al.getSessionProvider(sessionKey))
+ if providerName == "" && len(al.providerNames) > 0 {
+ providerName = al.providerNames[0]
+ }
+ return al.buildResponsesOptionsForProvider(sessionKey, providerName, maxTokens, temperature)
+}
+
+func (al *AgentLoop) buildResponsesOptionsForProvider(sessionKey, providerName string, maxTokens int64, temperature float64) map[string]interface{} {
+ if maxTokens <= 0 {
+ maxTokens = int64(al.maxTokensForProvider(providerName))
+ }
+ if math.IsNaN(temperature) {
+ temperature = al.temperatureForProvider(providerName)
+ }
options := map[string]interface{}{
"max_tokens": maxTokens,
"temperature": temperature,
}
- if strings.EqualFold(strings.TrimSpace(al.getSessionProvider(sessionKey)), "codex") {
+ if strings.EqualFold(strings.TrimSpace(providerName), "codex") {
if key := strings.TrimSpace(sessionKey); key != "" {
options["codex_execution_session"] = key
}
}
- responsesCfg := al.responsesConfigForSession(sessionKey)
+ responsesCfg := al.responsesConfigForProvider(providerName)
responseTools := make([]map[string]interface{}, 0, 2)
if responsesCfg.WebSearchEnabled {
webTool := map[string]interface{}{"type": "web_search"}
@@ -1604,6 +1771,60 @@ func (al *AgentLoop) buildResponsesOptions(sessionKey string, maxTokens int64, t
return options
}
+func (al *AgentLoop) maxTokensForProvider(name string) int {
+ if al == nil {
+ return 8192
+ }
+ providerName := strings.TrimSpace(name)
+ if providerName != "" {
+ if limit, ok := al.providerMaxTokens[providerName]; ok && limit > 0 {
+ return limit
+ }
+ }
+ if al.maxTokens > 0 {
+ return al.maxTokens
+ }
+ return 8192
+}
+
+func (al *AgentLoop) maxTokensForSession(sessionKey string) int {
+ if al == nil {
+ return 8192
+ }
+ name := strings.TrimSpace(al.getSessionProvider(sessionKey))
+ if name == "" && len(al.providerNames) > 0 {
+ name = al.providerNames[0]
+ }
+ return al.maxTokensForProvider(name)
+}
+
+func (al *AgentLoop) temperatureForProvider(name string) float64 {
+ if al == nil {
+ return 0.7
+ }
+ providerName := strings.TrimSpace(name)
+ if providerName != "" {
+ if value, ok := al.providerTemperatures[providerName]; ok && value != 0 {
+ return value
+ }
+ }
+ if al.temperature != 0 {
+ return al.temperature
+ }
+ return 0.7
+}
+
+func (al *AgentLoop) temperatureForSession(sessionKey string) float64 {
+ if al == nil {
+ return 0.7
+ }
+ name := strings.TrimSpace(al.getSessionProvider(sessionKey))
+ if name == "" && len(al.providerNames) > 0 {
+ name = al.providerNames[0]
+ }
+ return al.temperatureForProvider(name)
+}
+
func (al *AgentLoop) responsesConfigForSession(sessionKey string) config.ProviderResponsesConfig {
if al == nil {
return config.ProviderResponsesConfig{}
@@ -1612,6 +1833,13 @@ func (al *AgentLoop) responsesConfigForSession(sessionKey string) config.Provide
if name == "" && len(al.providerNames) > 0 {
name = al.providerNames[0]
}
+ return al.responsesConfigForProvider(name)
+}
+
+func (al *AgentLoop) responsesConfigForProvider(name string) config.ProviderResponsesConfig {
+ if al == nil {
+ return config.ProviderResponsesConfig{}
+ }
if name == "" {
return config.ProviderResponsesConfig{}
}
diff --git a/pkg/agent/loop_codex_options_test.go b/pkg/agent/loop_codex_options_test.go
index 5ca8b3d..517f302 100644
--- a/pkg/agent/loop_codex_options_test.go
+++ b/pkg/agent/loop_codex_options_test.go
@@ -1,6 +1,63 @@
package agent
-import "testing"
+import (
+ "context"
+ "errors"
+ "testing"
+
+ "github.com/YspCoder/clawgo/pkg/bus"
+ "github.com/YspCoder/clawgo/pkg/config"
+ "github.com/YspCoder/clawgo/pkg/providers"
+)
+
+type fallbackTestProvider struct {
+ response *providers.LLMResponse
+ err error
+ options map[string]interface{}
+}
+
+func (p *fallbackTestProvider) Chat(ctx context.Context, messages []providers.Message, tools []providers.ToolDefinition, model string, options map[string]interface{}) (*providers.LLMResponse, error) {
+ p.options = map[string]interface{}{}
+ for k, v := range options {
+ p.options[k] = v
+ }
+ if p.err != nil {
+ return nil, p.err
+ }
+ return p.response, nil
+}
+
+func (p *fallbackTestProvider) GetDefaultModel() string { return "fallback-model" }
+
+type sequenceProvider struct {
+ responses []*providers.LLMResponse
+ errs []error
+}
+
+func (p *sequenceProvider) Chat(ctx context.Context, messages []providers.Message, tools []providers.ToolDefinition, model string, options map[string]interface{}) (*providers.LLMResponse, error) {
+ if len(p.responses) == 0 && len(p.errs) == 0 {
+ return &providers.LLMResponse{Content: "ok", FinishReason: "stop"}, nil
+ }
+ resp := (*providers.LLMResponse)(nil)
+ if len(p.responses) > 0 {
+ resp = p.responses[0]
+ p.responses = p.responses[1:]
+ }
+ var err error
+ if len(p.errs) > 0 {
+ err = p.errs[0]
+ p.errs = p.errs[1:]
+ }
+ if err != nil {
+ return nil, err
+ }
+ if resp == nil {
+ return &providers.LLMResponse{Content: "ok", FinishReason: "stop"}, nil
+ }
+ return resp, nil
+}
+
+func (p *sequenceProvider) GetDefaultModel() string { return "sequence-model" }
func TestBuildResponsesOptionsAddsCodexExecutionSession(t *testing.T) {
loop := &AgentLoop{
@@ -42,3 +99,159 @@ func TestSyncSessionDefaultProviderOverridesStaleSessionProvider(t *testing.T) {
t.Fatalf("expected stale session provider to be replaced with current default, got %q", got)
}
}
+
+func TestSyncSessionDefaultProviderKeepsKnownSessionProvider(t *testing.T) {
+ loop := &AgentLoop{
+ providerNames: []string{"openai", "claude"},
+ sessionProvider: map[string]string{
+ "chat-1": "claude",
+ },
+ }
+
+ loop.syncSessionDefaultProvider("chat-1")
+
+ if got := loop.getSessionProvider("chat-1"); got != "claude" {
+ t.Fatalf("expected valid session provider to be preserved, got %q", got)
+ }
+}
+
+func TestMaxTokensForSessionUsesProviderOverride(t *testing.T) {
+ loop := &AgentLoop{
+ maxTokens: 4096,
+ providerNames: []string{"openai"},
+ sessionProvider: map[string]string{
+ "chat-1": "claude",
+ },
+ providerMaxTokens: map[string]int{
+ "claude": 16384,
+ },
+ }
+
+ if got := loop.maxTokensForSession("chat-1"); got != 16384 {
+ t.Fatalf("expected provider max_tokens override, got %d", got)
+ }
+}
+
+func TestMaxTokensForSessionFallsBackToAgentDefault(t *testing.T) {
+ loop := &AgentLoop{
+ maxTokens: 4096,
+ providerNames: []string{"openai"},
+ sessionProvider: map[string]string{},
+ providerMaxTokens: map[string]int{},
+ }
+
+ if got := loop.maxTokensForSession("chat-1"); got != 4096 {
+ t.Fatalf("expected fallback to agent default max_tokens, got %d", got)
+ }
+}
+
+func TestTemperatureForSessionUsesProviderOverride(t *testing.T) {
+ loop := &AgentLoop{
+ temperature: 0.7,
+ providerNames: []string{"openai"},
+ sessionProvider: map[string]string{
+ "chat-1": "claude",
+ },
+ providerTemperatures: map[string]float64{
+ "claude": 0.15,
+ },
+ }
+
+ if got := loop.temperatureForSession("chat-1"); got != 0.15 {
+ t.Fatalf("expected provider temperature override, got %v", got)
+ }
+}
+
+func TestTemperatureForSessionFallsBackToAgentDefault(t *testing.T) {
+ loop := &AgentLoop{
+ temperature: 0.7,
+ providerNames: []string{"openai"},
+ sessionProvider: map[string]string{},
+ providerTemperatures: map[string]float64{},
+ }
+
+ if got := loop.temperatureForSession("chat-1"); got != 0.7 {
+ t.Fatalf("expected fallback to agent default temperature, got %v", got)
+ }
+}
+
+func TestTryFallbackProvidersUsesFallbackProviderOptionsAndPersistsSelection(t *testing.T) {
+ fallback := &fallbackTestProvider{
+ response: &providers.LLMResponse{Content: "fallback", FinishReason: "stop"},
+ }
+ loop := &AgentLoop{
+ maxTokens: 4096,
+ temperature: 0.7,
+ providerNames: []string{"openai", "claude"},
+ sessionProvider: map[string]string{"chat-1": "openai"},
+ providerPool: map[string]providers.LLMProvider{"claude": fallback},
+ providerChain: []providerCandidate{{name: "openai", model: "gpt-a"}, {name: "claude", model: "claude-b"}},
+ providerMaxTokens: map[string]int{"claude": 16384},
+ providerTemperatures: map[string]float64{"claude": 0.15},
+ providerResponses: map[string]config.ProviderResponsesConfig{
+ "claude": {WebSearchEnabled: true},
+ },
+ }
+
+ resp, providerName, err := loop.tryFallbackProviders(context.Background(), bus.InboundMessage{SessionKey: "chat-1"}, nil, nil, errors.New("primary failed"))
+ if err != nil {
+ t.Fatalf("expected fallback success, got %v", err)
+ }
+ if resp == nil || resp.Content != "fallback" {
+ t.Fatalf("unexpected fallback response: %#v", resp)
+ }
+ if providerName != "claude" {
+ t.Fatalf("expected provider claude, got %q", providerName)
+ }
+ if got := loop.getSessionProvider("chat-1"); got != "claude" {
+ t.Fatalf("expected session provider to switch to fallback provider, got %q", got)
+ }
+ if got := fallback.options["max_tokens"]; got != int64(16384) {
+ t.Fatalf("expected fallback max_tokens 16384, got %#v", got)
+ }
+ if got := fallback.options["temperature"]; got != 0.15 {
+ t.Fatalf("expected fallback temperature 0.15, got %#v", got)
+ }
+ if _, ok := fallback.options["responses_tools"]; !ok {
+ t.Fatalf("expected fallback responses_tools to be populated")
+ }
+}
+
+func TestProcessMessageDoesNotPersistPartialAssistantToolHistoryOnFailure(t *testing.T) {
+ cfg := config.DefaultConfig()
+ cfg.Agents.Defaults.Workspace = t.TempDir()
+ cfg.Agents.Defaults.MaxToolIterations = 2
+
+ provider := &sequenceProvider{
+ responses: []*providers.LLMResponse{
+ {
+ Content: "",
+ ToolCalls: []providers.ToolCall{
+ {ID: "tool-1", Name: "read_file", Arguments: map[string]interface{}{"path": "missing.txt"}},
+ },
+ FinishReason: "tool_calls",
+ },
+ },
+ errs: []error{nil, errors.New("second pass failed")},
+ }
+
+ loop := NewAgentLoop(cfg, bus.NewMessageBus(), provider, nil)
+ _, err := loop.processMessage(context.Background(), bus.InboundMessage{
+ Channel: "cli",
+ ChatID: "direct",
+ SenderID: "user",
+ SessionKey: "cli:direct",
+ Content: "read file",
+ })
+ if err == nil {
+ t.Fatalf("expected processMessage error")
+ }
+
+ history := loop.sessions.GetHistory("cli:direct")
+ if len(history) != 1 {
+ t.Fatalf("expected only user message persisted on failure, got %d entries: %#v", len(history), history)
+ }
+ if history[0].Role != "user" || history[0].Content != "read file" {
+ t.Fatalf("unexpected persisted history: %#v", history)
+ }
+}
diff --git a/pkg/api/server.go b/pkg/api/server.go
index 46aaa1c..fba097d 100644
--- a/pkg/api/server.go
+++ b/pkg/api/server.go
@@ -25,6 +25,7 @@ import (
"sync"
"time"
+ "github.com/YspCoder/clawgo/pkg/bus"
"github.com/YspCoder/clawgo/pkg/channels"
cfgpkg "github.com/YspCoder/clawgo/pkg/config"
"github.com/YspCoder/clawgo/pkg/providers"
@@ -50,6 +51,7 @@ type Server struct {
onConfigAfter func(forceRuntimeReload bool) error
onCron func(action string, args map[string]interface{}) (interface{}, error)
onToolsCatalog func() interface{}
+ messageBus *bus.MessageBus
weixinChannel *channels.WeixinChannel
oauthFlowMu sync.Mutex
oauthFlows map[string]*providers.OAuthPendingFlow
@@ -57,6 +59,15 @@ type Server struct {
extraRoutes map[string]http.Handler
eventSubsMu sync.Mutex
eventSubs map[*websocket.Conn]struct{}
+ draftMu sync.RWMutex
+ channelDrafts channelDraftStore
+}
+
+type channelDraftStore struct {
+ Weixin *cfgpkg.WeixinConfig
+ Telegram *cfgpkg.TelegramConfig
+ Feishu *cfgpkg.FeishuConfig
+ weixinRuntime *channels.WeixinChannel
}
func NewServer(host string, port int, token string) *Server {
@@ -86,6 +97,7 @@ func (s *Server) SetChatHandler(fn func(ctx context.Context, sessionKey, content
func (s *Server) SetChatHistoryHandler(fn func(sessionKey string) []map[string]interface{}) {
s.onChatHistory = fn
}
+func (s *Server) SetMessageBus(mb *bus.MessageBus) { s.messageBus = mb }
func (s *Server) SetConfigAfterHook(fn func(forceRuntimeReload bool) error) { s.onConfigAfter = fn }
func (s *Server) SetCronHandler(fn func(action string, args map[string]interface{}) (interface{}, error)) {
s.onCron = fn
@@ -117,6 +129,356 @@ func (s *Server) SetWeixinChannel(ch *channels.WeixinChannel) {
}
}
+func cloneWeixinConfig(cfg cfgpkg.WeixinConfig) cfgpkg.WeixinConfig {
+ cp := cfg
+ cp.AllowFrom = append([]string(nil), cfg.AllowFrom...)
+ cp.Accounts = append([]cfgpkg.WeixinAccountConfig(nil), cfg.Accounts...)
+ return cp
+}
+
+func cloneTelegramConfig(cfg cfgpkg.TelegramConfig) cfgpkg.TelegramConfig {
+ cp := cfg
+ cp.AllowFrom = append([]string(nil), cfg.AllowFrom...)
+ cp.AllowChats = append([]string(nil), cfg.AllowChats...)
+ return cp
+}
+
+func cloneFeishuConfig(cfg cfgpkg.FeishuConfig) cfgpkg.FeishuConfig {
+ cp := cfg
+ cp.AllowFrom = append([]string(nil), cfg.AllowFrom...)
+ cp.AllowChats = append([]string(nil), cfg.AllowChats...)
+ return cp
+}
+
+func validChannelDraftName(name string) bool {
+ switch strings.ToLower(strings.TrimSpace(name)) {
+ case "weixin", "telegram", "feishu":
+ return true
+ default:
+ return false
+ }
+}
+
+func decodeMergedJSON[T any](current T, raw json.RawMessage) (T, error) {
+ out := current
+ if len(raw) == 0 || string(raw) == "null" {
+ return out, nil
+ }
+ baseBytes, err := json.Marshal(current)
+ if err != nil {
+ return out, err
+ }
+ merged := map[string]interface{}{}
+ if err := json.Unmarshal(baseBytes, &merged); err != nil {
+ return out, err
+ }
+ patch := map[string]interface{}{}
+ if err := json.Unmarshal(raw, &patch); err != nil {
+ return out, err
+ }
+ merged = mergeJSONMap(merged, patch)
+ mergedBytes, err := json.Marshal(merged)
+ if err != nil {
+ return out, err
+ }
+ if err := json.Unmarshal(mergedBytes, &out); err != nil {
+ return out, err
+ }
+ return out, nil
+}
+
+func (s *Server) syncWeixinDraftLocked() {
+ if s.channelDrafts.Weixin == nil || s.channelDrafts.weixinRuntime == nil {
+ return
+ }
+ snapshot := s.channelDrafts.weixinRuntime.SnapshotConfig()
+ s.channelDrafts.Weixin = &snapshot
+}
+
+func (s *Server) replaceWeixinDraftRuntimeLocked(cfg *cfgpkg.WeixinConfig) error {
+ if s.channelDrafts.weixinRuntime != nil {
+ _ = s.channelDrafts.weixinRuntime.Stop(context.Background())
+ s.channelDrafts.weixinRuntime = nil
+ }
+ if cfg == nil || !cfg.Enabled {
+ return nil
+ }
+ if s.messageBus == nil {
+ return fmt.Errorf("message bus not configured")
+ }
+ ch, err := channels.NewWeixinChannel(cloneWeixinConfig(*cfg), s.messageBus)
+ if err != nil {
+ return err
+ }
+ if err := ch.Start(context.Background()); err != nil {
+ return err
+ }
+ s.channelDrafts.weixinRuntime = ch
+ return nil
+}
+
+func (s *Server) clearChannelDraftsLocked() {
+ if s.channelDrafts.weixinRuntime != nil {
+ _ = s.channelDrafts.weixinRuntime.Stop(context.Background())
+ }
+ s.channelDrafts = channelDraftStore{}
+}
+
+func (s *Server) clearChannelDrafts() {
+ s.draftMu.Lock()
+ defer s.draftMu.Unlock()
+ s.clearChannelDraftsLocked()
+}
+
+func (s *Server) effectiveWeixinRuntime(persisted cfgpkg.WeixinConfig) (cfgpkg.WeixinConfig, *channels.WeixinChannel, bool) {
+ s.draftMu.Lock()
+ defer s.draftMu.Unlock()
+ if s.channelDrafts.Weixin != nil {
+ s.syncWeixinDraftLocked()
+ effective := cloneWeixinConfig(*s.channelDrafts.Weixin)
+ return effective, s.channelDrafts.weixinRuntime, true
+ }
+ return cloneWeixinConfig(persisted), s.weixinChannel, false
+}
+
+func (s *Server) currentChannelDraftPayload(cfg *cfgpkg.Config, channel string) map[string]interface{} {
+ channel = strings.ToLower(strings.TrimSpace(channel))
+ payload := map[string]interface{}{
+ "ok": true,
+ "channel": channel,
+ }
+ s.draftMu.Lock()
+ defer s.draftMu.Unlock()
+ switch channel {
+ case "weixin":
+ persisted := cloneWeixinConfig(cfg.Channels.Weixin)
+ var draft interface{}
+ effective := persisted
+ dirty := s.channelDrafts.Weixin != nil
+ if dirty {
+ s.syncWeixinDraftLocked()
+ effective = cloneWeixinConfig(*s.channelDrafts.Weixin)
+ draft = effective
+ }
+ payload["persisted"] = persisted
+ payload["draft"] = draft
+ payload["effective"] = effective
+ payload["dirty"] = dirty
+ payload["runtime_enabled"] = s.channelDrafts.weixinRuntime != nil && s.channelDrafts.weixinRuntime.IsRunning()
+ case "telegram":
+ persisted := cloneTelegramConfig(cfg.Channels.Telegram)
+ var draft interface{}
+ effective := persisted
+ dirty := s.channelDrafts.Telegram != nil
+ if dirty {
+ effective = cloneTelegramConfig(*s.channelDrafts.Telegram)
+ draft = effective
+ }
+ payload["persisted"] = persisted
+ payload["draft"] = draft
+ payload["effective"] = effective
+ payload["dirty"] = dirty
+ case "feishu":
+ persisted := cloneFeishuConfig(cfg.Channels.Feishu)
+ var draft interface{}
+ effective := persisted
+ dirty := s.channelDrafts.Feishu != nil
+ if dirty {
+ effective = cloneFeishuConfig(*s.channelDrafts.Feishu)
+ draft = effective
+ }
+ payload["persisted"] = persisted
+ payload["draft"] = draft
+ payload["effective"] = effective
+ payload["dirty"] = dirty
+ }
+ return payload
+}
+
+func (s *Server) applyChannelDrafts(cfg *cfgpkg.Config) {
+ if cfg == nil {
+ return
+ }
+ s.draftMu.Lock()
+ defer s.draftMu.Unlock()
+ s.syncWeixinDraftLocked()
+ if s.channelDrafts.Weixin != nil {
+ cfg.Channels.Weixin = cloneWeixinConfig(*s.channelDrafts.Weixin)
+ }
+ if s.channelDrafts.Telegram != nil {
+ cfg.Channels.Telegram = cloneTelegramConfig(*s.channelDrafts.Telegram)
+ }
+ if s.channelDrafts.Feishu != nil {
+ cfg.Channels.Feishu = cloneFeishuConfig(*s.channelDrafts.Feishu)
+ }
+}
+
+func (s *Server) handleWebUIChannelDraft(w http.ResponseWriter, r *http.Request) {
+ if !s.checkAuth(r) {
+ http.Error(w, "unauthorized", http.StatusUnauthorized)
+ return
+ }
+ cfg, err := s.loadConfig()
+ if err != nil {
+ http.Error(w, err.Error(), http.StatusInternalServerError)
+ return
+ }
+ switch r.Method {
+ case http.MethodGet:
+ channel := strings.ToLower(strings.TrimSpace(r.URL.Query().Get("channel")))
+ if channel == "" {
+ writeJSON(w, map[string]interface{}{
+ "ok": true,
+ "channels": map[string]interface{}{
+ "weixin": s.currentChannelDraftPayload(cfg, "weixin"),
+ "telegram": s.currentChannelDraftPayload(cfg, "telegram"),
+ "feishu": s.currentChannelDraftPayload(cfg, "feishu"),
+ },
+ })
+ return
+ }
+ if !validChannelDraftName(channel) {
+ http.Error(w, "unsupported channel", http.StatusBadRequest)
+ return
+ }
+ writeJSON(w, s.currentChannelDraftPayload(cfg, channel))
+ case http.MethodPost:
+ var body struct {
+ Channel string `json:"channel"`
+ Config json.RawMessage `json:"config"`
+ }
+ if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
+ http.Error(w, "invalid json", http.StatusBadRequest)
+ return
+ }
+ channel := strings.ToLower(strings.TrimSpace(body.Channel))
+ if !validChannelDraftName(channel) {
+ http.Error(w, "unsupported channel", http.StatusBadRequest)
+ return
+ }
+ s.draftMu.Lock()
+ switch channel {
+ case "weixin":
+ current := cfg.Channels.Weixin
+ if s.channelDrafts.Weixin != nil {
+ s.syncWeixinDraftLocked()
+ current = cloneWeixinConfig(*s.channelDrafts.Weixin)
+ }
+ next, err := decodeMergedJSON(current, body.Config)
+ if err != nil {
+ s.draftMu.Unlock()
+ http.Error(w, "invalid weixin config", http.StatusBadRequest)
+ return
+ }
+ next = cloneWeixinConfig(next)
+ if err := s.replaceWeixinDraftRuntimeLocked(&next); err != nil {
+ s.draftMu.Unlock()
+ http.Error(w, err.Error(), http.StatusBadRequest)
+ return
+ }
+ s.channelDrafts.Weixin = &next
+ case "telegram":
+ current := cfg.Channels.Telegram
+ if s.channelDrafts.Telegram != nil {
+ current = cloneTelegramConfig(*s.channelDrafts.Telegram)
+ }
+ next, err := decodeMergedJSON(current, body.Config)
+ if err != nil {
+ s.draftMu.Unlock()
+ http.Error(w, "invalid telegram config", http.StatusBadRequest)
+ return
+ }
+ next = cloneTelegramConfig(next)
+ s.channelDrafts.Telegram = &next
+ case "feishu":
+ current := cfg.Channels.Feishu
+ if s.channelDrafts.Feishu != nil {
+ current = cloneFeishuConfig(*s.channelDrafts.Feishu)
+ }
+ next, err := decodeMergedJSON(current, body.Config)
+ if err != nil {
+ s.draftMu.Unlock()
+ http.Error(w, "invalid feishu config", http.StatusBadRequest)
+ return
+ }
+ next = cloneFeishuConfig(next)
+ s.channelDrafts.Feishu = &next
+ }
+ s.draftMu.Unlock()
+ s.broadcastEvent(map[string]interface{}{
+ "type": "channel_draft_changed",
+ "channel": channel,
+ })
+ writeJSON(w, s.currentChannelDraftPayload(cfg, channel))
+ case http.MethodDelete:
+ channel := strings.ToLower(strings.TrimSpace(r.URL.Query().Get("channel")))
+ s.draftMu.Lock()
+ if channel == "" {
+ s.clearChannelDraftsLocked()
+ s.draftMu.Unlock()
+ writeJSON(w, map[string]interface{}{"ok": true, "cleared": "all"})
+ return
+ }
+ if !validChannelDraftName(channel) {
+ s.draftMu.Unlock()
+ http.Error(w, "unsupported channel", http.StatusBadRequest)
+ return
+ }
+ switch channel {
+ case "weixin":
+ if s.channelDrafts.weixinRuntime != nil {
+ _ = s.channelDrafts.weixinRuntime.Stop(context.Background())
+ s.channelDrafts.weixinRuntime = nil
+ }
+ s.channelDrafts.Weixin = nil
+ case "telegram":
+ s.channelDrafts.Telegram = nil
+ case "feishu":
+ s.channelDrafts.Feishu = nil
+ }
+ s.draftMu.Unlock()
+ s.broadcastEvent(map[string]interface{}{
+ "type": "channel_draft_changed",
+ "channel": channel,
+ })
+ writeJSON(w, map[string]interface{}{"ok": true, "channel": channel, "cleared": true})
+ default:
+ http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
+ }
+}
+
+func (s *Server) handleWebUIChannelDraftCommit(w http.ResponseWriter, r *http.Request) {
+ if !s.checkAuth(r) {
+ http.Error(w, "unauthorized", http.StatusUnauthorized)
+ return
+ }
+ if r.Method != http.MethodPost {
+ http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
+ return
+ }
+ cfg, err := s.loadConfig()
+ if err != nil {
+ http.Error(w, err.Error(), http.StatusInternalServerError)
+ return
+ }
+ s.applyChannelDrafts(cfg)
+ if err := s.persistWebUIConfig(cfg); err != nil {
+ var validationErr *configValidationError
+ if errors.As(err, &validationErr) {
+ writeJSONStatus(w, http.StatusBadRequest, map[string]interface{}{
+ "ok": false,
+ "error": validationErr.Error(),
+ "errors": validationErr.Fields,
+ })
+ return
+ }
+ http.Error(w, err.Error(), http.StatusBadRequest)
+ return
+ }
+ s.clearChannelDrafts()
+ writeJSON(w, map[string]interface{}{"ok": true, "committed": true})
+}
+
func (s *Server) handleWebUIEventsLive(w http.ResponseWriter, r *http.Request) {
if !s.checkAuth(r) {
http.Error(w, "unauthorized", http.StatusUnauthorized)
@@ -225,11 +587,14 @@ func (s *Server) Start(ctx context.Context) error {
mux.HandleFunc("/api/sessions", s.handleWebUISessions)
mux.HandleFunc("/api/memory", s.handleWebUIMemory)
mux.HandleFunc("/api/workspace_file", s.handleWebUIWorkspaceFile)
+ mux.HandleFunc("/api/workspace_docs", s.handleWebUIWorkspaceDocs)
mux.HandleFunc("/api/tool_allowlist_groups", s.handleWebUIToolAllowlistGroups)
mux.HandleFunc("/api/tools", s.handleWebUITools)
mux.HandleFunc("/api/mcp/install", s.handleWebUIMCPInstall)
mux.HandleFunc("/api/logs/live", s.handleWebUILogsLive)
mux.HandleFunc("/api/logs/recent", s.handleWebUILogsRecent)
+ mux.HandleFunc("/api/channels/draft", s.handleWebUIChannelDraft)
+ mux.HandleFunc("/api/channels/draft/commit", s.handleWebUIChannelDraftCommit)
s.extraRoutesMu.RLock()
for path, handler := range s.extraRoutes {
routePath := path
@@ -1227,11 +1592,17 @@ func (s *Server) handleWebUIWeixinLoginStart(w http.ResponseWriter, r *http.Requ
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
return
}
- if s.weixinChannel == nil {
+ cfg, err := s.loadConfig()
+ if err != nil {
+ http.Error(w, err.Error(), http.StatusInternalServerError)
+ return
+ }
+ _, ch, _ := s.effectiveWeixinRuntime(cfg.Channels.Weixin)
+ if ch == nil {
http.Error(w, "weixin channel unavailable", http.StatusServiceUnavailable)
return
}
- if _, err := s.weixinChannel.StartLogin(r.Context()); err != nil {
+ if _, err := ch.StartLogin(r.Context()); err != nil {
http.Error(w, err.Error(), http.StatusBadGateway)
return
}
@@ -1248,7 +1619,13 @@ func (s *Server) handleWebUIWeixinLoginCancel(w http.ResponseWriter, r *http.Req
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
return
}
- if s.weixinChannel == nil {
+ cfg, err := s.loadConfig()
+ if err != nil {
+ http.Error(w, err.Error(), http.StatusInternalServerError)
+ return
+ }
+ _, ch, _ := s.effectiveWeixinRuntime(cfg.Channels.Weixin)
+ if ch == nil {
http.Error(w, "weixin channel unavailable", http.StatusServiceUnavailable)
return
}
@@ -1259,7 +1636,7 @@ func (s *Server) handleWebUIWeixinLoginCancel(w http.ResponseWriter, r *http.Req
http.Error(w, "invalid json body", http.StatusBadRequest)
return
}
- if !s.weixinChannel.CancelPendingLogin(body.LoginID) {
+ if !ch.CancelPendingLogin(body.LoginID) {
http.Error(w, "login_id not found", http.StatusNotFound)
return
}
@@ -1283,8 +1660,14 @@ func (s *Server) handleWebUIWeixinQR(w http.ResponseWriter, r *http.Request) {
}
qrCode := ""
loginID := strings.TrimSpace(r.URL.Query().Get("login_id"))
- if loginID != "" && s.weixinChannel != nil {
- if pending := s.weixinChannel.PendingLoginByID(loginID); pending != nil {
+ cfg, err := s.loadConfig()
+ if err != nil {
+ http.Error(w, err.Error(), http.StatusInternalServerError)
+ return
+ }
+ _, ch, _ := s.effectiveWeixinRuntime(cfg.Channels.Weixin)
+ if loginID != "" && ch != nil {
+ if pending := ch.PendingLoginByID(loginID); pending != nil {
qrCode = fallbackString(pending.QRCodeImgContent, pending.QRCode)
}
}
@@ -1318,7 +1701,13 @@ func (s *Server) handleWebUIWeixinAccountRemove(w http.ResponseWriter, r *http.R
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
return
}
- if s.weixinChannel == nil {
+ cfg, err := s.loadConfig()
+ if err != nil {
+ http.Error(w, err.Error(), http.StatusInternalServerError)
+ return
+ }
+ _, ch, _ := s.effectiveWeixinRuntime(cfg.Channels.Weixin)
+ if ch == nil {
http.Error(w, "weixin channel unavailable", http.StatusServiceUnavailable)
return
}
@@ -1329,7 +1718,7 @@ func (s *Server) handleWebUIWeixinAccountRemove(w http.ResponseWriter, r *http.R
http.Error(w, "invalid json body", http.StatusBadRequest)
return
}
- if err := s.weixinChannel.RemoveAccount(body.BotID); err != nil {
+ if err := ch.RemoveAccount(body.BotID); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
@@ -1346,7 +1735,13 @@ func (s *Server) handleWebUIWeixinAccountDefault(w http.ResponseWriter, r *http.
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
return
}
- if s.weixinChannel == nil {
+ cfg, err := s.loadConfig()
+ if err != nil {
+ http.Error(w, err.Error(), http.StatusInternalServerError)
+ return
+ }
+ _, ch, _ := s.effectiveWeixinRuntime(cfg.Channels.Weixin)
+ if ch == nil {
http.Error(w, "weixin channel unavailable", http.StatusServiceUnavailable)
return
}
@@ -1357,7 +1752,7 @@ func (s *Server) handleWebUIWeixinAccountDefault(w http.ResponseWriter, r *http.
http.Error(w, "invalid json body", http.StatusBadRequest)
return
}
- if err := s.weixinChannel.SetDefaultAccount(body.BotID); err != nil {
+ if err := ch.SetDefaultAccount(body.BotID); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
@@ -1373,25 +1768,35 @@ func (s *Server) webUIWeixinStatusPayload(ctx context.Context) (map[string]inter
"error": err.Error(),
}, http.StatusInternalServerError
}
- weixinCfg := cfg.Channels.Weixin
- if s.weixinChannel == nil {
+ persistedCfg := cloneWeixinConfig(cfg.Channels.Weixin)
+ weixinCfg, ch, usingDraft := s.effectiveWeixinRuntime(persistedCfg)
+ if ch == nil {
return map[string]interface{}{
- "ok": false,
- "enabled": weixinCfg.Enabled,
- "base_url": weixinCfg.BaseURL,
- "error": "weixin channel unavailable",
+ "ok": false,
+ "enabled": weixinCfg.Enabled,
+ "config_enabled": persistedCfg.Enabled,
+ "runtime_enabled": false,
+ "draft_dirty": usingDraft,
+ "base_url": weixinCfg.BaseURL,
+ "error": "weixin channel unavailable",
}, http.StatusOK
}
- pendingLogins, err := s.weixinChannel.RefreshLoginStatuses(ctx)
+ pendingLogins, err := ch.RefreshLoginStatuses(ctx)
if err != nil {
return map[string]interface{}{
- "ok": false,
- "enabled": weixinCfg.Enabled,
- "base_url": weixinCfg.BaseURL,
- "error": err.Error(),
+ "ok": false,
+ "enabled": weixinCfg.Enabled,
+ "config_enabled": persistedCfg.Enabled,
+ "runtime_enabled": ch.IsRunning(),
+ "draft_dirty": usingDraft,
+ "base_url": weixinCfg.BaseURL,
+ "error": err.Error(),
}, http.StatusOK
}
- accounts := s.weixinChannel.ListAccounts()
+ if usingDraft {
+ weixinCfg = ch.SnapshotConfig()
+ }
+ accounts := ch.ListAccounts()
pendingPayload := make([]map[string]interface{}, 0, len(pendingLogins))
for _, pending := range pendingLogins {
pendingPayload = append(pendingPayload, map[string]interface{}{
@@ -1409,10 +1814,13 @@ func (s *Server) webUIWeixinStatusPayload(ctx context.Context) (map[string]inter
firstPending = pendingLogins[0]
}
return map[string]interface{}{
- "ok": true,
- "enabled": weixinCfg.Enabled,
- "base_url": fallbackString(weixinCfg.BaseURL, "https://ilinkai.weixin.qq.com"),
- "pending_logins": pendingPayload,
+ "ok": true,
+ "enabled": weixinCfg.Enabled,
+ "config_enabled": persistedCfg.Enabled,
+ "runtime_enabled": ch.IsRunning(),
+ "draft_dirty": usingDraft,
+ "base_url": fallbackString(weixinCfg.BaseURL, "https://ilinkai.weixin.qq.com"),
+ "pending_logins": pendingPayload,
"pending_login": map[string]interface{}{
"login_id": pendingString(firstPending, "login_id"),
"qr_code": pendingString(firstPending, "qr_code"),
@@ -2976,9 +3384,6 @@ func (s *Server) handleWebUIMemory(w http.ResponseWriter, r *http.Request) {
path := strings.TrimSpace(r.URL.Query().Get("path"))
if path == "" {
files := make([]string, 0, 16)
- if _, err := os.Stat(filepath.Join(s.workspacePath, "MEMORY.md")); err == nil {
- files = append(files, "MEMORY.md")
- }
entries, err := os.ReadDir(memoryDir)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
@@ -2993,11 +3398,7 @@ func (s *Server) handleWebUIMemory(w http.ResponseWriter, r *http.Request) {
writeJSON(w, map[string]interface{}{"ok": true, "files": files})
return
}
- baseDir := memoryDir
- if strings.EqualFold(path, "MEMORY.md") {
- baseDir = strings.TrimSpace(s.workspacePath)
- }
- clean, content, found, err := readRelativeTextFile(baseDir, path)
+ clean, content, found, err := readRelativeTextFile(memoryDir, path)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
@@ -3073,6 +3474,70 @@ func (s *Server) handleWebUIWorkspaceFile(w http.ResponseWriter, r *http.Request
}
}
+// workspaceDocFiles is the fixed allowlist of top-level workspace documents
+// that handleWebUIWorkspaceDocs may list or read. isWorkspaceDocAllowed
+// compares request paths against these names with an exact, case-sensitive
+// match, and the listing response preserves this order.
+var workspaceDocFiles = []string{
+ "AGENTS.md",
+ "BOOT.md",
+ "BOOTSTRAP.md",
+ "HEARTBEAT.md",
+ "IDENTITY.md",
+ "MEMORY.md",
+ "SOUL.md",
+ "TOOLS.md",
+ "USER.md",
+}
+
+// handleWebUIWorkspaceDocs serves the allowlisted workspace documents.
+//
+// Only authenticated GET requests are accepted (401 / 405 otherwise).
+//
+// With a non-empty "path" query parameter the handler returns that document
+// as {"ok": true, "path": ..., "content": ...}. The path must equal one of the
+// workspaceDocFiles entries exactly, otherwise the request is rejected with
+// 400. An allowlisted document that readRelativeTextFile reports as not found
+// yields 404.
+//
+// Without "path" the handler responds with {"ok": true, "files": [...]},
+// listing every allowlisted document that readRelativeTextFile finds, in
+// workspaceDocFiles order.
+func (s *Server) handleWebUIWorkspaceDocs(w http.ResponseWriter, r *http.Request) {
+	if !s.checkAuth(r) {
+		http.Error(w, "unauthorized", http.StatusUnauthorized)
+		return
+	}
+	if r.Method != http.MethodGet {
+		http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
+		return
+	}
+	workspace := strings.TrimSpace(s.workspacePath)
+	path := strings.TrimSpace(r.URL.Query().Get("path"))
+	if path != "" {
+		// Check the allowlist before touching the filesystem so that only
+		// the known document names can ever be read through this endpoint.
+		if !isWorkspaceDocAllowed(path) {
+			http.Error(w, "invalid path", http.StatusBadRequest)
+			return
+		}
+		clean, content, found, err := readRelativeTextFile(workspace, path)
+		if err != nil {
+			http.Error(w, err.Error(), relativeFilePathStatus(err))
+			return
+		}
+		if !found {
+			// A missing document is a client-visible "not found", not a
+			// server failure.
+			http.Error(w, os.ErrNotExist.Error(), http.StatusNotFound)
+			return
+		}
+		writeJSON(w, map[string]interface{}{"ok": true, "path": clean, "content": content})
+		return
+	}
+	files := make([]string, 0, len(workspaceDocFiles))
+	for _, name := range workspaceDocFiles {
+		_, _, found, err := readRelativeTextFile(workspace, name)
+		if err != nil {
+			http.Error(w, err.Error(), relativeFilePathStatus(err))
+			return
+		}
+		if !found {
+			continue
+		}
+		files = append(files, name)
+	}
+	writeJSON(w, map[string]interface{}{"ok": true, "files": files})
+}
+
+// isWorkspaceDocAllowed reports whether name is exactly one of the entries in
+// workspaceDocFiles. The comparison is case-sensitive and performs no path
+// cleaning.
+func isWorkspaceDocAllowed(name string) bool {
+	for i := range workspaceDocFiles {
+		if workspaceDocFiles[i] == name {
+			return true
+		}
+	}
+	return false
+}
+
func (s *Server) handleWebUILogsRecent(w http.ResponseWriter, r *http.Request) {
if !s.checkAuth(r) {
http.Error(w, "unauthorized", http.StatusUnauthorized)
diff --git a/pkg/api/server_test.go b/pkg/api/server_test.go
index e8191e0..ec7642e 100644
--- a/pkg/api/server_test.go
+++ b/pkg/api/server_test.go
@@ -12,6 +12,7 @@ import (
"testing"
"time"
+ "github.com/YspCoder/clawgo/pkg/bus"
cfgpkg "github.com/YspCoder/clawgo/pkg/config"
"github.com/gorilla/websocket"
)
@@ -139,6 +140,100 @@ func TestHandleWebUIConfigPostSavesNormalizedConfig(t *testing.T) {
}
}
+func TestHandleWebUIChannelDraftCommitPersistsDrafts(t *testing.T) {
+ t.Parallel()
+
+ tmp := t.TempDir()
+ cfgPath := filepath.Join(tmp, "config.json")
+ cfg := cfgpkg.DefaultConfig()
+ if err := cfgpkg.SaveConfig(cfgPath, cfg); err != nil {
+ t.Fatalf("save config: %v", err)
+ }
+
+ srv := NewServer("127.0.0.1", 0, "")
+ srv.SetConfigPath(cfgPath)
+ srv.SetMessageBus(bus.NewMessageBus())
+ hookCalled := 0
+ srv.SetConfigAfterHook(func(forceRuntimeReload bool) error {
+ hookCalled++
+ return nil
+ })
+
+ draftReq := httptest.NewRequest(http.MethodPost, "/api/channels/draft", strings.NewReader(`{"channel":"telegram","config":{"enabled":true,"token":"bot-token","streaming":true}}`))
+ draftReq.Header.Set("Content-Type", "application/json")
+ draftRec := httptest.NewRecorder()
+ srv.handleWebUIChannelDraft(draftRec, draftReq)
+ if draftRec.Code != http.StatusOK {
+ t.Fatalf("expected 200 from draft save, got %d: %s", draftRec.Code, draftRec.Body.String())
+ }
+
+ commitReq := httptest.NewRequest(http.MethodPost, "/api/channels/draft/commit", nil)
+ commitRec := httptest.NewRecorder()
+ srv.handleWebUIChannelDraftCommit(commitRec, commitReq)
+ if commitRec.Code != http.StatusOK {
+ t.Fatalf("expected 200 from draft commit, got %d: %s", commitRec.Code, commitRec.Body.String())
+ }
+ if hookCalled != 1 {
+ t.Fatalf("expected reload hook once, got %d", hookCalled)
+ }
+
+ updated, err := cfgpkg.LoadConfig(cfgPath)
+ if err != nil {
+ t.Fatalf("reload config: %v", err)
+ }
+ if !updated.Channels.Telegram.Enabled {
+ t.Fatalf("expected telegram enabled after draft commit")
+ }
+ if updated.Channels.Telegram.Token != "bot-token" {
+ t.Fatalf("expected telegram token to persist, got %q", updated.Channels.Telegram.Token)
+ }
+}
+
+// TestHandleWebUIWeixinStatusReflectsDraftRuntime persists a config with
+// weixin disabled, stages a draft that enables it, and checks that the status
+// endpoint reports the draft as dirty, the persisted config as disabled, and
+// the runtime as enabled.
+func TestHandleWebUIWeixinStatusReflectsDraftRuntime(t *testing.T) {
+	t.Parallel()
+
+	configFile := filepath.Join(t.TempDir(), "config.json")
+	baseline := cfgpkg.DefaultConfig()
+	baseline.Channels.Weixin.Enabled = false
+	if err := cfgpkg.SaveConfig(configFile, baseline); err != nil {
+		t.Fatalf("save config: %v", err)
+	}
+
+	server := NewServer("127.0.0.1", 0, "")
+	server.SetConfigPath(configFile)
+	server.SetMessageBus(bus.NewMessageBus())
+
+	// Stage a weixin draft that flips the channel on without committing it.
+	const draftBody = `{"channel":"weixin","config":{"enabled":true,"base_url":"https://ilinkai.weixin.qq.com"}}`
+	saveReq := httptest.NewRequest(http.MethodPost, "/api/channels/draft", strings.NewReader(draftBody))
+	saveReq.Header.Set("Content-Type", "application/json")
+	saveRec := httptest.NewRecorder()
+	server.handleWebUIChannelDraft(saveRec, saveReq)
+	if saveRec.Code != http.StatusOK {
+		t.Fatalf("expected 200 from weixin draft save, got %d: %s", saveRec.Code, saveRec.Body.String())
+	}
+
+	statusRec := httptest.NewRecorder()
+	server.handleWebUIWeixinStatus(statusRec, httptest.NewRequest(http.MethodGet, "/api/weixin/status", nil))
+	if statusRec.Code != http.StatusOK {
+		t.Fatalf("expected 200 from status, got %d: %s", statusRec.Code, statusRec.Body.String())
+	}
+
+	var payload map[string]interface{}
+	if err := json.Unmarshal(statusRec.Body.Bytes(), &payload); err != nil {
+		t.Fatalf("decode status: %v", err)
+	}
+
+	// Checked in a fixed order so the first failing field is reported
+	// deterministically.
+	expectations := []struct {
+		key  string
+		want bool
+	}{
+		{"draft_dirty", true},
+		{"config_enabled", false},
+		{"runtime_enabled", true},
+	}
+	for _, exp := range expectations {
+		if got := payload[exp.key]; got != exp.want {
+			t.Fatalf("expected %s=%v, got %#v", exp.key, exp.want, got)
+		}
+	}
+}
+
func TestWithCORSEchoesPreflightHeaders(t *testing.T) {
t.Parallel()
@@ -255,13 +350,10 @@ func TestSaveProviderConfigForcesRuntimeReload(t *testing.T) {
}
}
-func TestHandleWebUIMemoryListsAndReadsWorkspaceMemoryFile(t *testing.T) {
+func TestHandleWebUIMemoryListsAndReadsMemoryDirFile(t *testing.T) {
t.Parallel()
tmp := t.TempDir()
- if err := os.WriteFile(filepath.Join(tmp, "MEMORY.md"), []byte("# long-term\n"), 0o644); err != nil {
- t.Fatalf("write workspace memory: %v", err)
- }
if err := os.MkdirAll(filepath.Join(tmp, "memory"), 0o755); err != nil {
t.Fatalf("mkdir memory dir: %v", err)
}
@@ -285,11 +377,11 @@ func TestHandleWebUIMemoryListsAndReadsWorkspaceMemoryFile(t *testing.T) {
if err := json.Unmarshal(listRec.Body.Bytes(), &listPayload); err != nil {
t.Fatalf("decode list payload: %v", err)
}
- if len(listPayload.Files) < 2 || listPayload.Files[0] != "MEMORY.md" {
- t.Fatalf("expected MEMORY.md in memory file list, got %+v", listPayload.Files)
+ if len(listPayload.Files) != 1 || listPayload.Files[0] != "2026-03-19.md" {
+ t.Fatalf("expected only memory dir files, got %+v", listPayload.Files)
}
- readReq := httptest.NewRequest(http.MethodGet, "/api/memory?path=MEMORY.md", nil)
+ readReq := httptest.NewRequest(http.MethodGet, "/api/memory?path=2026-03-19.md", nil)
readRec := httptest.NewRecorder()
srv.handleWebUIMemory(readRec, readReq)
if readRec.Code != http.StatusOK {
@@ -303,11 +395,71 @@ func TestHandleWebUIMemoryListsAndReadsWorkspaceMemoryFile(t *testing.T) {
if err := json.Unmarshal(readRec.Body.Bytes(), &readPayload); err != nil {
t.Fatalf("decode read payload: %v", err)
}
- if readPayload.Path != "MEMORY.md" || readPayload.Content != "# long-term\n" {
+ if readPayload.Path != "2026-03-19.md" || readPayload.Content != "daily\n" {
t.Fatalf("unexpected memory payload: %+v", readPayload)
}
}
+// TestHandleWebUIWorkspaceDocsListAndRead seeds a workspace with AGENTS.md and
+// MEMORY.md, then verifies that the workspace-docs handler lists exactly those
+// two documents in order and returns the content of AGENTS.md when asked for
+// it by path.
+func TestHandleWebUIWorkspaceDocsListAndRead(t *testing.T) {
+	t.Parallel()
+
+	workspace := t.TempDir()
+	seeded := []struct{ name, content string }{
+		{"AGENTS.md", "agents\n"},
+		{"MEMORY.md", "memory\n"},
+	}
+	for _, doc := range seeded {
+		if err := os.WriteFile(filepath.Join(workspace, doc.name), []byte(doc.content), 0o644); err != nil {
+			t.Fatalf("write %s: %v", doc.name, err)
+		}
+	}
+
+	server := NewServer("127.0.0.1", 0, "")
+	server.SetWorkspacePath(workspace)
+
+	// Listing: no "path" query parameter.
+	listRec := httptest.NewRecorder()
+	server.handleWebUIWorkspaceDocs(listRec, httptest.NewRequest(http.MethodGet, "/api/workspace_docs", nil))
+	if listRec.Code != http.StatusOK {
+		t.Fatalf("expected 200, got %d: %s", listRec.Code, listRec.Body.String())
+	}
+
+	var listing struct {
+		OK    bool     `json:"ok"`
+		Files []string `json:"files"`
+	}
+	if err := json.Unmarshal(listRec.Body.Bytes(), &listing); err != nil {
+		t.Fatalf("decode payload: %v", err)
+	}
+	if !listing.OK {
+		t.Fatalf("expected ok=true, got %+v", listing)
+	}
+	if len(listing.Files) != 2 {
+		t.Fatalf("expected 2 existing docs, got %+v", listing.Files)
+	}
+	if first := listing.Files[0]; first != "AGENTS.md" {
+		t.Fatalf("unexpected first doc payload: %+v", first)
+	}
+	if second := listing.Files[1]; second != "MEMORY.md" {
+		t.Fatalf("unexpected second doc payload: %+v", second)
+	}
+
+	// Reading: a single allowlisted document by path.
+	docRec := httptest.NewRecorder()
+	server.handleWebUIWorkspaceDocs(docRec, httptest.NewRequest(http.MethodGet, "/api/workspace_docs?path=AGENTS.md", nil))
+	if docRec.Code != http.StatusOK {
+		t.Fatalf("expected 200, got %d: %s", docRec.Code, docRec.Body.String())
+	}
+	var doc struct {
+		OK      bool   `json:"ok"`
+		Path    string `json:"path"`
+		Content string `json:"content"`
+	}
+	if err := json.Unmarshal(docRec.Body.Bytes(), &doc); err != nil {
+		t.Fatalf("decode read payload: %v", err)
+	}
+	if doc.Path != "AGENTS.md" || doc.Content != "agents\n" {
+		t.Fatalf("unexpected read payload: %+v", doc)
+	}
+}
+
func TestHandleWebUIChatLive(t *testing.T) {
t.Parallel()
diff --git a/pkg/channels/weixin.go b/pkg/channels/weixin.go
index c64aa90..0fba49b 100644
--- a/pkg/channels/weixin.go
+++ b/pkg/channels/weixin.go
@@ -389,6 +389,22 @@ func (c *WeixinChannel) ListAccounts() []WeixinAccountSnapshot {
return out
}
+// SnapshotConfig returns a copy of the channel's Weixin configuration, taken
+// while holding the read lock.
+//
+// In the returned copy AllowFrom is a fresh slice, Accounts is rebuilt from
+// accountConfigsLocked, DefaultBotID is the trimmed result of
+// defaultBotIDLocked, and the BotID, BotToken, IlinkUserID, ContextToken and
+// GetUpdatesBuf fields are blanked.
+func (c *WeixinChannel) SnapshotConfig() config.WeixinConfig {
+	c.mu.RLock()
+	defer c.mu.RUnlock()
+
+	snapshot := c.config
+
+	// Blank these fields on the copy only; c.config itself is not modified.
+	snapshot.BotID, snapshot.BotToken, snapshot.IlinkUserID = "", "", ""
+	snapshot.ContextToken, snapshot.GetUpdatesBuf = "", ""
+
+	// Detach the slices so callers cannot alias the channel's own storage.
+	snapshot.AllowFrom = append([]string(nil), snapshot.AllowFrom...)
+	accounts := c.accountConfigsLocked()
+	snapshot.Accounts = append([]config.WeixinAccountConfig(nil), accounts...)
+
+	defaultID := c.defaultBotIDLocked()
+	snapshot.DefaultBotID = strings.TrimSpace(defaultID)
+	return snapshot
+}
+
func (c *WeixinChannel) SetDefaultAccount(botID string) error {
botID = strings.TrimSpace(botID)
if botID == "" {
diff --git a/pkg/config/config.go b/pkg/config/config.go
index 7267ddb..cd729af 100644
--- a/pkg/config/config.go
+++ b/pkg/config/config.go
@@ -226,6 +226,8 @@ type ProviderConfig struct {
APIKey string `json:"api_key" env:"CLAWGO_PROVIDERS_{{.Name}}_API_KEY"`
APIBase string `json:"api_base" env:"CLAWGO_PROVIDERS_{{.Name}}_API_BASE"`
Models []string `json:"models" env:"CLAWGO_PROVIDERS_{{.Name}}_MODELS"`
+ MaxTokens int `json:"max_tokens,omitempty"`
+ Temperature float64 `json:"temperature,omitempty"`
SupportsResponsesCompact bool `json:"supports_responses_compact" env:"CLAWGO_PROVIDERS_{{.Name}}_SUPPORTS_RESPONSES_COMPACT"`
Auth string `json:"auth" env:"CLAWGO_PROVIDERS_{{.Name}}_AUTH"`
TimeoutSec int `json:"timeout_sec" env:"CLAWGO_PROVIDERS_PROXY_TIMEOUT_SEC"`
diff --git a/pkg/config/normalized.go b/pkg/config/normalized.go
index 9871422..20fe9b3 100644
--- a/pkg/config/normalized.go
+++ b/pkg/config/normalized.go
@@ -53,6 +53,8 @@ type NormalizedRuntimeRouterConfig struct {
type NormalizedRuntimeProviderConfig struct {
Auth string `json:"auth,omitempty"`
APIBase string `json:"api_base,omitempty"`
+ MaxTokens int `json:"max_tokens,omitempty"`
+ Temperature float64 `json:"temperature,omitempty"`
TimeoutSec int `json:"timeout_sec,omitempty"`
OAuth ProviderOAuthConfig `json:"oauth,omitempty"`
RuntimePersist bool `json:"runtime_persist,omitempty"`
@@ -143,6 +145,8 @@ func (c *Config) NormalizedView() NormalizedConfig {
view.Runtime.Providers[name] = NormalizedRuntimeProviderConfig{
Auth: pc.Auth,
APIBase: pc.APIBase,
+ MaxTokens: pc.MaxTokens,
+ Temperature: pc.Temperature,
TimeoutSec: pc.TimeoutSec,
OAuth: pc.OAuth,
RuntimePersist: pc.RuntimePersist,
@@ -232,6 +236,12 @@ func (c *Config) ApplyNormalizedView(view NormalizedConfig) {
current := c.Models.Providers[name]
current.Auth = strings.TrimSpace(item.Auth)
current.APIBase = strings.TrimSpace(item.APIBase)
+ if item.MaxTokens > 0 {
+ current.MaxTokens = item.MaxTokens
+ } else if item.MaxTokens == 0 {
+ current.MaxTokens = 0
+ }
+ current.Temperature = item.Temperature
if item.TimeoutSec > 0 {
current.TimeoutSec = item.TimeoutSec
}
diff --git a/pkg/config/normalized_test.go b/pkg/config/normalized_test.go
index 31e26a7..520e08b 100644
--- a/pkg/config/normalized_test.go
+++ b/pkg/config/normalized_test.go
@@ -5,6 +5,13 @@ import "testing"
func TestNormalizedViewProjectsCoreAndRuntime(t *testing.T) {
cfg := DefaultConfig()
cfg.Agents.Router.Enabled = true
+ cfg.Models.Providers["openai"] = ProviderConfig{
+ APIBase: "https://api.openai.com/v1",
+ Models: []string{"gpt-5.4"},
+ MaxTokens: 12288,
+ Temperature: 0.35,
+ TimeoutSec: 90,
+ }
cfg.Agents.Subagents["coder"] = SubagentConfig{
Enabled: true,
Role: "coding",
@@ -27,4 +34,10 @@ func TestNormalizedViewProjectsCoreAndRuntime(t *testing.T) {
if !view.Runtime.Router.Enabled || view.Runtime.Router.Strategy != "rules_first" {
t.Fatalf("unexpected runtime router: %+v", view.Runtime.Router)
}
+ if got := view.Runtime.Providers["openai"].MaxTokens; got != 12288 {
+ t.Fatalf("expected provider max_tokens in normalized runtime view, got %d", got)
+ }
+ if got := view.Runtime.Providers["openai"].Temperature; got != 0.35 {
+ t.Fatalf("expected provider temperature in normalized runtime view, got %v", got)
+ }
}
diff --git a/pkg/config/validate.go b/pkg/config/validate.go
index 529824b..1963e8d 100644
--- a/pkg/config/validate.go
+++ b/pkg/config/validate.go
@@ -478,6 +478,9 @@ func validateProviderConfig(path string, p ProviderConfig) []error {
if p.TimeoutSec <= 0 {
errs = append(errs, fmt.Errorf("%s.timeout_sec must be > 0", path))
}
+ if p.MaxTokens < 0 {
+ errs = append(errs, fmt.Errorf("%s.max_tokens must be >= 0", path))
+ }
switch authMode {
case "", "bearer", "oauth", "none", "hybrid":
default:
diff --git a/pkg/providers/codex_provider.go b/pkg/providers/codex_provider.go
index 054c1ef..2a8c4e5 100644
--- a/pkg/providers/codex_provider.go
+++ b/pkg/providers/codex_provider.go
@@ -495,9 +495,7 @@ func (p *CodexProvider) doStreamAttempt(req *http.Request, attempt authAttempt,
if typ := strings.TrimSpace(fmt.Sprintf("%v", obj["type"])); typ == "response.completed" {
completed = true
if respObj, ok := obj["response"]; ok {
- if b, err := json.Marshal(respObj); err == nil {
- finalJSON = b
- }
+ finalJSON = mergeStreamFinalJSON(finalJSON, respObj)
}
}
}
diff --git a/pkg/providers/codex_provider_test.go b/pkg/providers/codex_provider_test.go
index 306c591..43e301d 100644
--- a/pkg/providers/codex_provider_test.go
+++ b/pkg/providers/codex_provider_test.go
@@ -214,6 +214,27 @@ func TestCodexProviderChatFallsBackToHTTPStreamResponse(t *testing.T) {
}
}
+// TestCodexProviderChatMergesLateUsageFromStreamingCompletion streams two
+// "response.completed" events, the first carrying the output text and the
+// second carrying only usage, and checks that Chat returns both the content
+// and the usage figures.
+func TestCodexProviderChatMergesLateUsageFromStreamingCompletion(t *testing.T) {
+	events := []string{
+		"data: {\"type\":\"response.completed\",\"response\":{\"status\":\"completed\",\"output_text\":\"hello\"}}\n\n",
+		"data: {\"type\":\"response.completed\",\"response\":{\"usage\":{\"input_tokens\":1,\"output_tokens\":2,\"total_tokens\":3}}}\n\n",
+	}
+	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		w.Header().Set("Content-Type", "text/event-stream")
+		for _, event := range events {
+			_, _ = fmt.Fprint(w, event)
+		}
+	}))
+	defer server.Close()
+
+	provider := NewCodexProvider("codex", "test-api-key", server.URL, "gpt-5.4", false, "", 5*time.Second, nil)
+	prompt := []Message{{Role: "user", Content: "hi"}}
+	resp, err := provider.Chat(t.Context(), prompt, nil, "gpt-5.4", nil)
+	if err != nil {
+		t.Fatalf("Chat error: %v", err)
+	}
+	if resp.Content != "hello" {
+		t.Fatalf("unexpected response content: %q", resp.Content)
+	}
+	if u := resp.Usage; u == nil || u.PromptTokens != 1 || u.CompletionTokens != 2 || u.TotalTokens != 3 {
+		t.Fatalf("unexpected usage: %#v", u)
+	}
+}
+
func TestCodexHandleAttemptFailureMarksAPIKeyCooldown(t *testing.T) {
provider := NewCodexProvider("codex-websocket-failure", "test-api-key", "", "gpt-5.4", false, "", 5*time.Second, nil)
provider.handleAttemptFailure(authAttempt{kind: "api_key", token: "test-api-key"}, http.StatusTooManyRequests, []byte(`{"error":{"message":"rate limit exceeded"}}`))
diff --git a/pkg/providers/http_provider.go b/pkg/providers/http_provider.go
index e8620b5..7da55ff 100644
--- a/pkg/providers/http_provider.go
+++ b/pkg/providers/http_provider.go
@@ -1022,15 +1022,13 @@ func (p *HTTPProvider) doStreamAttempt(req *http.Request, attempt authAttempt, o
if err := json.Unmarshal([]byte(payload), &obj); err == nil {
if typ := strings.TrimSpace(fmt.Sprintf("%v", obj["type"])); typ == "response.completed" {
if respObj, ok := obj["response"]; ok {
- if b, err := json.Marshal(respObj); err == nil {
- finalJSON = b
- }
+ finalJSON = mergeStreamFinalJSON(finalJSON, respObj)
}
}
if choices, ok := obj["choices"]; ok {
- if b, err := json.Marshal(map[string]interface{}{"choices": choices, "usage": obj["usage"]}); err == nil {
- finalJSON = b
- }
+ finalJSON = mergeStreamFinalJSON(finalJSON, map[string]interface{}{"choices": choices, "usage": obj["usage"]})
+ } else if _, ok := obj["usage"]; ok && len(finalJSON) > 0 {
+ finalJSON = mergeStreamFinalJSON(finalJSON, map[string]interface{}{"usage": obj["usage"]})
}
}
}
@@ -1049,6 +1047,56 @@ func (p *HTTPProvider) doStreamAttempt(req *http.Request, attempt authAttempt, o
return finalJSON, resp.StatusCode, ctype, false, nil
}
+// mergeStreamFinalJSON folds one more streamed payload into the JSON document
+// accumulated so far and returns the resulting bytes.
+//
+//   - A nil incoming value leaves existing untouched.
+//   - An incoming value that is not a map[string]interface{} replaces existing
+//     with its own JSON encoding.
+//   - An incoming map is encoded as-is when existing is empty; otherwise it is
+//     merged into the decoded existing object via mergeStringAnyMaps. If
+//     existing does not decode to a JSON object, the merge starts from an
+//     empty object.
+//
+// Whenever encoding fails, existing is returned unchanged.
+func mergeStreamFinalJSON(existing []byte, incoming interface{}) []byte {
+	// encodeOrKeep marshals v, falling back to the bytes accumulated so far.
+	encodeOrKeep := func(v interface{}) []byte {
+		out, err := json.Marshal(v)
+		if err != nil {
+			return existing
+		}
+		return out
+	}
+
+	if incoming == nil {
+		return existing
+	}
+	fields, isObject := incoming.(map[string]interface{})
+	switch {
+	case !isObject:
+		return encodeOrKeep(incoming)
+	case len(existing) == 0:
+		return encodeOrKeep(fields)
+	}
+
+	var accumulated map[string]interface{}
+	if err := json.Unmarshal(existing, &accumulated); err != nil || accumulated == nil {
+		accumulated = map[string]interface{}{}
+	}
+	return encodeOrKeep(mergeStringAnyMaps(accumulated, fields))
+}
+
+// mergeStringAnyMaps merges src into dst and returns dst, allocating it first
+// when dst is nil.
+//
+// Nil values in src are skipped, so they never overwrite an existing entry.
+// When both sides hold a map[string]interface{} under the same key, the two
+// are merged recursively; in every other case the src value replaces whatever
+// dst had. dst is modified in place.
+func mergeStringAnyMaps(dst, src map[string]interface{}) map[string]interface{} {
+	if dst == nil {
+		dst = map[string]interface{}{}
+	}
+	for k, incoming := range src {
+		if incoming == nil {
+			continue
+		}
+		srcChild, srcIsMap := incoming.(map[string]interface{})
+		dstChild, dstIsMap := dst[k].(map[string]interface{})
+		if srcIsMap && dstIsMap {
+			dst[k] = mergeStringAnyMaps(dstChild, srcChild)
+		} else {
+			dst[k] = incoming
+		}
+	}
+	return dst
+}
+
func shouldRetryOAuthQuota(status int, body []byte) bool {
_, retry := classifyOAuthFailure(status, body)
return retry
diff --git a/pkg/providers/oauth_test.go b/pkg/providers/oauth_test.go
index 78a6ca3..2856674 100644
--- a/pkg/providers/oauth_test.go
+++ b/pkg/providers/oauth_test.go
@@ -197,6 +197,44 @@ func TestHTTPProviderOAuthSwitchesAccountOnQuota(t *testing.T) {
}
}
+// TestHTTPProviderOpenAICompatStreamMergesLateUsage streams a chat-completions
+// chunk with the choices followed by a separate usage-only chunk, and checks
+// that the body returned by doStreamAttempt parses into a response carrying
+// both the content and the usage figures.
+func TestHTTPProviderOpenAICompatStreamMergesLateUsage(t *testing.T) {
+	t.Parallel()
+
+	chunks := [][]byte{
+		[]byte("data: {\"choices\":[{\"index\":0,\"message\":{\"content\":\"hello\"},\"finish_reason\":\"stop\"}]}\n\n"),
+		[]byte("data: {\"usage\":{\"prompt_tokens\":1,\"completion_tokens\":2,\"total_tokens\":3}}\n\n"),
+	}
+	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		if r.URL.Path != "/v1/chat/completions" {
+			http.NotFound(w, r)
+			return
+		}
+		w.Header().Set("Content-Type", "text/event-stream")
+		for _, chunk := range chunks {
+			_, _ = w.Write(chunk)
+		}
+	}))
+	defer server.Close()
+
+	provider := NewHTTPProvider("openai", "token", server.URL+"/v1", "gpt-test", false, "api_key", 5*time.Second, nil)
+	endpoint := server.URL + "/v1/chat/completions"
+	req, err := http.NewRequestWithContext(context.Background(), http.MethodPost, endpoint, nil)
+	if err != nil {
+		t.Fatalf("new request failed: %v", err)
+	}
+
+	body, status, _, _, err := provider.doStreamAttempt(req, authAttempt{kind: "api_key", token: "token"}, nil)
+	if err != nil {
+		t.Fatalf("stream attempt failed: %v", err)
+	}
+	if status != http.StatusOK {
+		t.Fatalf("unexpected status: %d", status)
+	}
+
+	resp, err := parseOpenAICompatResponse(body)
+	if err != nil {
+		t.Fatalf("parse response failed: %v", err)
+	}
+	if resp.Content != "hello" {
+		t.Fatalf("unexpected response content: %q", resp.Content)
+	}
+	if u := resp.Usage; u == nil || u.PromptTokens != 1 || u.CompletionTokens != 2 || u.TotalTokens != 3 {
+		t.Fatalf("unexpected usage: %#v", u)
+	}
+}
+
func TestOAuthManagerPreRefreshesExpiringSession(t *testing.T) {
t.Parallel()