From c5033e8484bf02c525d38b5fbcbc46b3c17fb927 Mon Sep 17 00:00:00 2001
From: DBT
Date: Sat, 28 Feb 2026 15:06:55 +0000
Subject: [PATCH] provider fallback: switch to proxy_fallbacks on API failure
 in message/system loops

---
 pkg/agent/loop.go | 79 +++++++++++++++++++++++++++++++++++++++++++++--
 1 file changed, 77 insertions(+), 2 deletions(-)

diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go
index 9f2db01..414400c 100644
--- a/pkg/agent/loop.go
+++ b/pkg/agent/loop.go
@@ -58,6 +58,8 @@ type AgentLoop struct {
 	intentHints     map[string]string
 	sessionRunMu    sync.Mutex
 	sessionRunLocks map[string]*sync.Mutex
+	providerNames   []string
+	providerPool    map[string]providers.LLMProvider
 }
 
 // StartupCompactionReport provides startup memory/session maintenance stats.
@@ -236,6 +238,32 @@ func NewAgentLoop(cfg *config.Config, msgBus *bus.MessageBus, provider providers
 		sessionRunLocks: map[string]*sync.Mutex{},
 	}
 
+	// Initialize provider fallback chain (primary + proxy_fallbacks).
+	// Index 0 of providerNames is always the primary provider.
+	primaryName := cfg.Agents.Defaults.Proxy
+	if primaryName == "" {
+		primaryName = "proxy"
+	}
+	loop.providerPool = map[string]providers.LLMProvider{primaryName: provider}
+	loop.providerNames = []string{primaryName}
+	for _, name := range cfg.Agents.Defaults.ProxyFallbacks {
+		if name == "" {
+			continue
+		}
+		// Every registered name has a pool entry, so this also skips duplicates.
+		if _, seen := loop.providerPool[name]; seen {
+			continue
+		}
+		fallback, err := providers.CreateProviderByName(cfg, name)
+		if err != nil {
+			// Keep starting up without this fallback, but leave a trace.
+			logger.WarnCF("agent", "LLM fallback provider unavailable", map[string]interface{}{"provider": name, "error": err.Error()})
+			continue
+		}
+		loop.providerPool[name] = fallback
+		loop.providerNames = append(loop.providerNames, name)
+	}
+
 	// Inject recursive run logic so subagents can use full tool-calling flows.
 	subagentManager.SetRunFunc(func(ctx context.Context, task, channel, chatID string) (string, error) {
 		sessionKey := fmt.Sprintf("subagent:%d", os.Getpid()) // Use PID/randomized key to reduce session key collisions.
@@ -309,6 +337,44 @@ func (al *AgentLoop) lockSessionRun(sessionKey string) func() {
 	return func() { mu.Unlock() }
 }
 
+// tryFallbackProviders replays a failed chat request against each configured
+// fallback provider in order, skipping the primary at index 0. It returns the
+// first non-nil response. When no fallback was attempted it returns primaryErr
+// unchanged; when every attempt failed it wraps primaryErr and appends the last
+// fallback error, so callers matching on the primary error text still see it.
+func (al *AgentLoop) tryFallbackProviders(ctx context.Context, messages []providers.Message, toolDefs []providers.ToolDefinition, options map[string]interface{}, primaryErr error) (*providers.LLMResponse, error) {
+	if len(al.providerNames) <= 1 {
+		return nil, primaryErr
+	}
+	var lastErr error
+	for _, name := range al.providerNames[1:] {
+		// There is no point trying further providers once the context is done.
+		if ctx.Err() != nil {
+			break
+		}
+		p := al.providerPool[name]
+		if p == nil {
+			continue
+		}
+		resp, err := p.Chat(ctx, messages, toolDefs, al.model, options)
+		if err != nil {
+			lastErr = err
+			continue
+		}
+		if resp == nil {
+			// Never hand (nil, nil) to callers; count an empty reply as a failure.
+			lastErr = fmt.Errorf("fallback provider %q returned no response", name)
+			continue
+		}
+		logger.WarnCF("agent", "LLM fallback provider switched", map[string]interface{}{"provider": name})
+		return resp, nil
+	}
+	if lastErr == nil {
+		return nil, primaryErr
+	}
+	return nil, fmt.Errorf("%w (last fallback error: %v)", primaryErr, lastErr)
+}
+
 func (al *AgentLoop) processInbound(ctx context.Context, msg bus.InboundMessage) {
 	taskID := fmt.Sprintf("%s-%d", shortSessionKey(msg.SessionKey), time.Now().Unix()%100000)
 	started := time.Now()
@@ -686,6 +752,10 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage)
 			response, err = al.provider.Chat(ctx, messages, providerToolDefs, al.model, options)
 		}
 
+		if err != nil {
+			// Yields either a non-nil response or the error handled below.
+			response, err = al.tryFallbackProviders(ctx, messages, providerToolDefs, options, err)
+		}
 		if err != nil {
 			errText := strings.ToLower(err.Error())
 			if strings.Contains(errText, "no tool call found for function call output") {
@@ -1015,11 +1085,16 @@ func (al *AgentLoop) processSystemMessage(ctx context.Context, msg bus.InboundMe
 			"tools_json": formatToolsForLog(providerToolDefs),
 		})
 
-		response, err := al.provider.Chat(ctx, messages, providerToolDefs, al.model, map[string]interface{}{
+		options := map[string]interface{}{
 			"max_tokens":  8192,
 			"temperature": 0.7,
-		})
+		}
+		response, err := al.provider.Chat(ctx, messages, providerToolDefs, al.model, options)
 
+		if err != nil {
+			// Yields either a non-nil response or the error handled below.
+			response, err = al.tryFallbackProviders(ctx, messages, providerToolDefs, options, err)
+		}
 		if err != nil {
 			errText := strings.ToLower(err.Error())
 			if strings.Contains(errText, "no tool call found for function call output") {