provider fallback: switch to proxy_fallbacks on API failure in message/system loops

This commit is contained in:
DBT
2026-02-28 15:06:55 +00:00
parent e729e50d02
commit c5033e8484

View File

@@ -58,6 +58,8 @@ type AgentLoop struct {
intentHints map[string]string
sessionRunMu sync.Mutex
sessionRunLocks map[string]*sync.Mutex
providerNames []string
providerPool map[string]providers.LLMProvider
}
// StartupCompactionReport provides startup memory/session maintenance stats.
@@ -236,6 +238,35 @@ func NewAgentLoop(cfg *config.Config, msgBus *bus.MessageBus, provider providers
sessionRunLocks: map[string]*sync.Mutex{},
}
// Initialize provider fallback chain (primary + proxy_fallbacks).
loop.providerPool = map[string]providers.LLMProvider{}
loop.providerNames = []string{}
primaryName := cfg.Agents.Defaults.Proxy
if primaryName == "" {
primaryName = "proxy"
}
loop.providerPool[primaryName] = provider
loop.providerNames = append(loop.providerNames, primaryName)
for _, name := range cfg.Agents.Defaults.ProxyFallbacks {
if name == "" {
continue
}
dup := false
for _, existing := range loop.providerNames {
if existing == name {
dup = true
break
}
}
if dup {
continue
}
if p2, err := providers.CreateProviderByName(cfg, name); err == nil {
loop.providerPool[name] = p2
loop.providerNames = append(loop.providerNames, name)
}
}
// Inject recursive run logic so subagents can use full tool-calling flows.
subagentManager.SetRunFunc(func(ctx context.Context, task, channel, chatID string) (string, error) {
sessionKey := fmt.Sprintf("subagent:%d", os.Getpid()) // Use PID/randomized key to reduce session key collisions.
@@ -309,6 +340,27 @@ func (al *AgentLoop) lockSessionRun(sessionKey string) func() {
return func() { mu.Unlock() }
}
// tryFallbackProviders retries a failed chat request against the fallback
// providers, walking al.providerNames in order and skipping its first entry
// (the provider whose failure is reported in primaryErr).
//
// It returns the first non-nil response obtained without error. If every
// fallback fails, is missing from al.providerPool, or the context ends before
// a fallback succeeds, it returns nil together with the most recent error
// (primaryErr when no fallback was actually attempted).
//
// NOTE(review): every fallback is called with al.model unchanged; this assumes
// the fallback providers accept the same model name — confirm.
func (al *AgentLoop) tryFallbackProviders(ctx context.Context, messages []providers.Message, toolDefs []providers.ToolDefinition, options map[string]interface{}, primaryErr error) (*providers.LLMResponse, error) {
	if len(al.providerNames) <= 1 {
		return nil, primaryErr
	}

	lastErr := primaryErr
	for _, name := range al.providerNames[1:] {
		// Once the context is cancelled or past its deadline no further request
		// is attempted; stop here and keep the last real failure
		// instead of replacing it with a context error.
		if ctx.Err() != nil {
			return nil, lastErr
		}

		p, ok := al.providerPool[name]
		if !ok || p == nil {
			continue
		}

		resp, err := p.Chat(ctx, messages, toolDefs, al.model, options)
		if err != nil {
			lastErr = err
			continue
		}
		if resp == nil {
			// A nil response with a nil error must not be reported as success:
			// callers would clear their error and continue with no response.
			lastErr = fmt.Errorf("fallback provider %q returned no response", name)
			continue
		}

		logger.WarnCF("agent", "LLM fallback provider switched", map[string]interface{}{"provider": name})
		return resp, nil
	}
	return nil, lastErr
}
func (al *AgentLoop) processInbound(ctx context.Context, msg bus.InboundMessage) {
taskID := fmt.Sprintf("%s-%d", shortSessionKey(msg.SessionKey), time.Now().Unix()%100000)
started := time.Now()
@@ -686,6 +738,14 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage)
response, err = al.provider.Chat(ctx, messages, providerToolDefs, al.model, options)
}
if err != nil {
if fb, ferr := al.tryFallbackProviders(ctx, messages, providerToolDefs, options, err); ferr == nil && fb != nil {
response = fb
err = nil
} else {
err = ferr
}
}
if err != nil {
errText := strings.ToLower(err.Error())
if strings.Contains(errText, "no tool call found for function call output") {
@@ -1015,11 +1075,20 @@ func (al *AgentLoop) processSystemMessage(ctx context.Context, msg bus.InboundMe
"tools_json": formatToolsForLog(providerToolDefs),
})
response, err := al.provider.Chat(ctx, messages, providerToolDefs, al.model, map[string]interface{}{
options := map[string]interface{}{
"max_tokens": 8192,
"temperature": 0.7,
})
}
response, err := al.provider.Chat(ctx, messages, providerToolDefs, al.model, options)
if err != nil {
if fb, ferr := al.tryFallbackProviders(ctx, messages, providerToolDefs, options, err); ferr == nil && fb != nil {
response = fb
err = nil
} else {
err = ferr
}
}
if err != nil {
errText := strings.ToLower(err.Error())
if strings.Contains(errText, "no tool call found for function call output") {