From ddca0605c4991b688f5d28715604c9489f354aeb Mon Sep 17 00:00:00 2001 From: lpf Date: Wed, 18 Feb 2026 21:58:03 +0800 Subject: [PATCH] Add multi-service provider mode --- cmd/clawgo/main.go | 26 +- config.example.json | 16 +- pkg/agent/loop.go | 307 +++++++++++---- pkg/agent/loop_fallback_test.go | 109 +++++- pkg/agent/loop_model_switch_test.go | 32 +- pkg/config/config.go | 24 +- pkg/config/validate.go | 54 ++- pkg/providers/openai_provider.go | 354 +++++++++++++++--- ...http_provider_test.go => provider_test.go} | 25 +- 9 files changed, 776 insertions(+), 171 deletions(-) rename pkg/providers/{http_provider_test.go => provider_test.go} (83%) diff --git a/cmd/clawgo/main.go b/cmd/clawgo/main.go index 9317ddf..0ba042b 100644 --- a/cmd/clawgo/main.go +++ b/cmd/clawgo/main.go @@ -1248,9 +1248,16 @@ func buildGatewayRuntime(ctx context.Context, cfg *config.Config, msgBus *bus.Me return nil, nil, fmt.Errorf("create channel manager: %w", err) } + activeProvider := cfg.Providers.Proxy + if name := strings.TrimSpace(cfg.Agents.Defaults.Proxy); name != "" && name != "proxy" { + if p, ok := cfg.Providers.Proxies[name]; ok { + activeProvider = p + } + } + var transcriber *voice.GroqTranscriber - if cfg.Providers.Proxy.APIKey != "" && strings.Contains(cfg.Providers.Proxy.APIBase, "groq.com") { - transcriber = voice.NewGroqTranscriber(cfg.Providers.Proxy.APIKey) + if activeProvider.APIKey != "" && strings.Contains(activeProvider.APIBase, "groq.com") { + transcriber = voice.NewGroqTranscriber(activeProvider.APIKey) logger.InfoC("voice", "Groq voice transcription enabled via Proxy config") } @@ -1501,7 +1508,20 @@ func statusCmd() { } if _, err := os.Stat(configPath); err == nil { - fmt.Printf("Model: %s\n", cfg.Agents.Defaults.Model) + activeProvider := cfg.Providers.Proxy + if name := strings.TrimSpace(cfg.Agents.Defaults.Proxy); name != "" && name != "proxy" { + if p, ok := cfg.Providers.Proxies[name]; ok { + activeProvider = p + } + } + activeModel := "" + for _, m := range activeProvider.Models { + if s := strings.TrimSpace(m); s != "" { + activeModel = s + break + } + } + fmt.Printf("Model: %s\n", activeModel) fmt.Printf("CLIProxyAPI Base: %s\n", cfg.Providers.Proxy.APIBase) hasKey := cfg.Providers.Proxy.APIKey != "" status := "not set" diff --git a/config.example.json b/config.example.json index 0cd38a1..61cabdc 100644 --- a/config.example.json +++ b/config.example.json @@ -2,8 +2,8 @@ "agents": { "defaults": { "workspace": "~/.clawgo/workspace", - "model": "glm-4.7", - "model_fallbacks": ["gpt-4o-mini", "deepseek-chat"], + "proxy": "proxy", + "proxy_fallbacks": ["backup"], "max_tokens": 8192, "temperature": 0.7, "max_tool_iterations": 20, @@ -57,8 +57,20 @@ "proxy": { "api_key": "YOUR_CLIPROXYAPI_KEY", "api_base": "http://localhost:8080/v1", + "protocol": "chat_completions", + "models": ["glm-4.7", "gpt-4o-mini"], "auth": "bearer", "timeout_sec": 90 + }, + "proxies": { + "backup": { + "api_key": "YOUR_BACKUP_PROXY_KEY", + "api_base": "http://localhost:8081/v1", + "protocol": "responses", + "models": ["gpt-4o-mini", "deepseek-chat"], + "auth": "bearer", + "timeout_sec": 90 + } } }, "tools": { diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go index f990a90..00518fc 100644 --- a/pkg/agent/loop.go +++ b/pkg/agent/loop.go @@ -14,6 +14,7 @@ import ( "os" "path/filepath" "regexp" + "sort" "strconv" "strings" "sync" @@ -67,25 +68,28 @@ type autonomySession struct { } type AgentLoop struct { - bus *bus.MessageBus - provider providers.LLMProvider - workspace string - model string - modelFallbacks []string - maxIterations int - sessions *session.SessionManager - contextBuilder *ContextBuilder - tools *tools.ToolRegistry - orchestrator *tools.Orchestrator - running atomic.Bool - compactionCfg config.ContextCompactionConfig - llmCallTimeout time.Duration - workersMu sync.Mutex - workers map[string]*sessionWorker - autoLearnMu sync.Mutex - autoLearners map[string]*autoLearner - autonomyMu sync.Mutex - autonomyBySess map[string]*autonomySession + bus *bus.MessageBus + provider providers.LLMProvider + providersByProxy map[string]providers.LLMProvider + modelsByProxy map[string][]string + proxy string + proxyFallbacks []string + workspace string + model string + maxIterations int + sessions *session.SessionManager + contextBuilder *ContextBuilder + tools *tools.ToolRegistry + orchestrator *tools.Orchestrator + running atomic.Bool + compactionCfg config.ContextCompactionConfig + llmCallTimeout time.Duration + workersMu sync.Mutex + workers map[string]*sessionWorker + autoLearnMu sync.Mutex + autoLearners map[string]*autoLearner + autonomyMu sync.Mutex + autonomyBySess map[string]*autonomySession } type taskExecutionDirectives struct { @@ -218,22 +222,51 @@ func NewAgentLoop(cfg *config.Config, msgBus *bus.MessageBus, provider providers sessionsManager := session.NewSessionManager(filepath.Join(filepath.Dir(cfg.WorkspacePath()), "sessions")) + providersByProxy, err := providers.CreateProviders(cfg) + if err != nil { + logger.WarnCF("agent", "Create providers map failed, fallback to single provider mode", map[string]interface{}{ + logger.FieldError: err.Error(), + }) + providersByProxy = map[string]providers.LLMProvider{ + "proxy": provider, + } + } + modelsByProxy := map[string][]string{} + for _, name := range providers.ListProviderNames(cfg) { + modelsByProxy[name] = providers.GetProviderModels(cfg, name) + } + + primaryProxy := strings.TrimSpace(cfg.Agents.Defaults.Proxy) + if primaryProxy == "" { + primaryProxy = "proxy" + } + if p, ok := providersByProxy[primaryProxy]; ok { + provider = p + } else if p, ok := providersByProxy["proxy"]; ok { + primaryProxy = "proxy" + provider = p + } + defaultModel := defaultModelFromModels(modelsByProxy[primaryProxy], provider) + loop := &AgentLoop{ - bus: msgBus, - provider: provider, - workspace: workspace, - model: cfg.Agents.Defaults.Model, - modelFallbacks: cfg.Agents.Defaults.ModelFallbacks, - maxIterations: cfg.Agents.Defaults.MaxToolIterations, - sessions: sessionsManager, - contextBuilder: NewContextBuilder(workspace, cfg.Memory, func() []string { return toolsRegistry.GetSummaries() }), - tools: toolsRegistry, - orchestrator: orchestrator, - compactionCfg: cfg.Agents.Defaults.ContextCompaction, - llmCallTimeout: time.Duration(cfg.Providers.Proxy.TimeoutSec) * time.Second, - workers: make(map[string]*sessionWorker), - autoLearners: make(map[string]*autoLearner), - autonomyBySess: make(map[string]*autonomySession), + bus: msgBus, + provider: provider, + providersByProxy: providersByProxy, + modelsByProxy: modelsByProxy, + proxy: primaryProxy, + proxyFallbacks: parseStringList(cfg.Agents.Defaults.ProxyFallbacks), + workspace: workspace, + model: defaultModel, + maxIterations: cfg.Agents.Defaults.MaxToolIterations, + sessions: sessionsManager, + contextBuilder: NewContextBuilder(workspace, cfg.Memory, func() []string { return toolsRegistry.GetSummaries() }), + tools: toolsRegistry, + orchestrator: orchestrator, + compactionCfg: cfg.Agents.Defaults.ContextCompaction, + llmCallTimeout: time.Duration(cfg.Providers.Proxy.TimeoutSec) * time.Second, + workers: make(map[string]*sessionWorker), + autoLearners: make(map[string]*autoLearner), + autonomyBySess: make(map[string]*autonomySession), } // 注入递归运行逻辑,使 subagent 具备 full tool-calling 能力 @@ -1652,38 +1685,102 @@ func (al *AgentLoop) callLLMWithModelFallback( tools []providers.ToolDefinition, options map[string]interface{}, ) (*providers.LLMResponse, error) { - candidates := al.modelCandidates() + if len(al.providersByProxy) == 0 { + candidates := al.modelCandidates() + var lastErr error + + for idx, model := range candidates { + response, err := al.provider.Chat(ctx, messages, tools, model, options) + if err == nil { + if al.model != model { + logger.WarnCF("agent", "Model switched after quota/rate-limit error", map[string]interface{}{ + "from_model": al.model, + "to_model": model, + }) + al.model = model + } + return response, nil + } + + lastErr = err + if !shouldRetryWithFallbackModel(err) { + return nil, err + } + + if idx < len(candidates)-1 { + logger.DebugCF("agent", "Model request failed, trying fallback model", map[string]interface{}{ + "failed_model": model, + "next_model": candidates[idx+1], + logger.FieldError: err.Error(), + }) + continue + } + } + + return nil, fmt.Errorf("all configured models failed; last error: %w", lastErr) + } + + proxyCandidates := al.proxyCandidates() var lastErr error - for idx, model := range candidates { - response, err := al.provider.Chat(ctx, messages, tools, model, options) - if err == nil { - if al.model != model { - logger.WarnCF("agent", "Model switched after quota/rate-limit error", map[string]interface{}{ - "from_model": al.model, - "to_model": model, - }) - al.model = model - } - return response, nil - } - - lastErr = err - if !shouldRetryWithFallbackModel(err) { - return nil, err - } - - if idx < len(candidates)-1 { - logger.DebugCF("agent", "Model request failed, trying fallback model", map[string]interface{}{ - "failed_model": model, - "next_model": candidates[idx+1], - logger.FieldError: err.Error(), - }) + for pidx, proxyName := range proxyCandidates { + proxyProvider, ok := al.providersByProxy[proxyName] + if !ok || proxyProvider == nil { continue } + modelCandidates := al.modelCandidatesForProxy(proxyName) + if len(modelCandidates) == 0 { + continue + } + + for midx, model := range modelCandidates { + response, err := proxyProvider.Chat(ctx, messages, tools, model, options) + if err == nil { + if al.proxy != proxyName { + logger.WarnCF("agent", "Proxy switched after model unavailability", map[string]interface{}{ + "from_proxy": al.proxy, + "to_proxy": proxyName, + }) + al.proxy = proxyName + al.provider = proxyProvider + } + if al.model != model { + logger.WarnCF("agent", "Model switched after availability error", map[string]interface{}{ + "from_model": al.model, + "to_model": model, + "proxy": proxyName, + }) + al.model = model + } + return response, nil + } + + lastErr = err + if !shouldRetryWithFallbackModel(err) { + return nil, err + } + + if midx < len(modelCandidates)-1 { + logger.DebugCF("agent", "Model request failed, trying next model in proxy", map[string]interface{}{ + "proxy": proxyName, + "failed_model": model, + "next_model": modelCandidates[midx+1], + logger.FieldError: err.Error(), + }) + continue + } + + if pidx < len(proxyCandidates)-1 { + logger.DebugCF("agent", "All models failed in proxy, trying next proxy", map[string]interface{}{ + "failed_proxy": proxyName, + "next_proxy": proxyCandidates[pidx+1], + logger.FieldError: err.Error(), + }) + } + } } - return nil, fmt.Errorf("all configured models failed; last error: %w", lastErr) + return nil, fmt.Errorf("all configured proxies/models failed; last error: %w", lastErr) } func (al *AgentLoop) modelCandidates() []string { @@ -1700,13 +1797,67 @@ func (al *AgentLoop) modelCandidates() []string { } add(al.model) - for _, m := range al.modelFallbacks { + + return candidates +} + +func (al *AgentLoop) modelCandidatesForProxy(proxyName string) []string { + candidates := []string{} + seen := map[string]bool{} + + add := func(model string) { + m := strings.TrimSpace(model) + if m == "" || seen[m] { + return + } + seen[m] = true + candidates = append(candidates, m) + } + + add(al.model) + + models := al.modelsByProxy[proxyName] + for _, m := range models { add(m) } return candidates } +func (al *AgentLoop) proxyCandidates() []string { + candidates := []string{} + seen := map[string]bool{} + add := func(name string) { + n := strings.TrimSpace(name) + if n == "" || seen[n] { + return + } + if _, ok := al.providersByProxy[n]; !ok { + return + } + seen[n] = true + candidates = append(candidates, n) + } + + add(al.proxy) + for _, n := range al.proxyFallbacks { + add(n) + } + + rest := make([]string, 0, len(al.providersByProxy)) + for name := range al.providersByProxy { + if seen[name] { + continue + } + rest = append(rest, name) + } + sort.Strings(rest) + for _, name := range rest { + add(name) + } + return candidates +} + func isQuotaOrRateLimitError(err error) bool { if err == nil { return false @@ -2421,7 +2572,7 @@ func (al *AgentLoop) handleSlashCommand(ctx context.Context, msg bus.InboundMess return true, "", fmt.Errorf("status failed: %w", err) } return true, fmt.Sprintf("Model: %s\nAPI Base: %s\nLogging: %v\nConfig: %s", - cfg.Agents.Defaults.Model, + al.model, cfg.Providers.Proxy.APIBase, cfg.Logging.Enabled, al.getConfigPathForCommands(), @@ -2702,18 +2853,34 @@ func (al *AgentLoop) triggerGatewayReloadFromAgent() (bool, error) { func (al *AgentLoop) applyRuntimeModelConfig(path string, value interface{}) { switch path { - case "agents.defaults.model": - newModel := strings.TrimSpace(fmt.Sprintf("%v", value)) - if newModel == "" { - return + case "agents.defaults.proxy": + newProxy := strings.TrimSpace(fmt.Sprintf("%v", value)) + if newProxy != "" { + al.proxy = newProxy + if p, ok := al.providersByProxy[newProxy]; ok { + al.provider = p + } + al.model = defaultModelFromModels(al.modelsByProxy[newProxy], al.provider) } - al.model = newModel - case "agents.defaults.model_fallbacks": - al.modelFallbacks = parseModelFallbacks(value) + case "agents.defaults.proxy_fallbacks": + al.proxyFallbacks = parseStringList(value) } } -func parseModelFallbacks(value interface{}) []string { +func defaultModelFromModels(models []string, provider providers.LLMProvider) string { + for _, m := range models { + model := strings.TrimSpace(m) + if model != "" { + return model + } + } + if provider != nil { + return strings.TrimSpace(provider.GetDefaultModel()) + } + return "" +} + +func parseStringList(value interface{}) []string { switch v := value.(type) { case []string: out := make([]string, 0, len(v)) diff --git a/pkg/agent/loop_fallback_test.go b/pkg/agent/loop_fallback_test.go index 4fabe92..caaba61 100644 --- a/pkg/agent/loop_fallback_test.go +++ b/pkg/agent/loop_fallback_test.go @@ -39,9 +39,15 @@ func TestCallLLMWithModelFallback_RetriesOnUnknownProvider(t *testing.T) { } al := &AgentLoop{ - provider: p, - model: "gemini-3-flash", - modelFallbacks: []string{"gpt-4o-mini"}, + provider: p, + proxy: "proxy", + model: "gemini-3-flash", + providersByProxy: map[string]providers.LLMProvider{ + "proxy": p, + }, + modelsByProxy: map[string][]string{ + "proxy": []string{"gemini-3-flash", "gpt-4o-mini"}, + }, } resp, err := al.callLLMWithModelFallback(context.Background(), nil, nil, nil) @@ -71,9 +77,15 @@ func TestCallLLMWithModelFallback_RetriesOnGateway502(t *testing.T) { } al := &AgentLoop{ - provider: p, - model: "gemini-3-flash", - modelFallbacks: []string{"gpt-4o-mini"}, + provider: p, + proxy: "proxy", + model: "gemini-3-flash", + providersByProxy: map[string]providers.LLMProvider{ + "proxy": p, + }, + modelsByProxy: map[string][]string{ + "proxy": []string{"gemini-3-flash", "gpt-4o-mini"}, + }, } resp, err := al.callLLMWithModelFallback(context.Background(), nil, nil, nil) @@ -100,9 +112,15 @@ func TestCallLLMWithModelFallback_RetriesOnGateway524(t *testing.T) { } al := &AgentLoop{ - provider: p, - model: "gemini-3-flash", - modelFallbacks: []string{"gpt-4o-mini"}, + provider: p, + proxy: "proxy", + model: "gemini-3-flash", + providersByProxy: map[string]providers.LLMProvider{ + "proxy": p, + }, + modelsByProxy: map[string][]string{ + "proxy": []string{"gemini-3-flash", "gpt-4o-mini"}, + }, } resp, err := al.callLLMWithModelFallback(context.Background(), nil, nil, nil) @@ -129,9 +147,15 @@ func TestCallLLMWithModelFallback_RetriesOnAuthUnavailable500(t *testing.T) { } al := &AgentLoop{ - provider: p, - model: "gemini-3-flash", - modelFallbacks: []string{"gpt-4o-mini"}, + provider: p, + proxy: "proxy", + model: "gemini-3-flash", + providersByProxy: map[string]providers.LLMProvider{ + "proxy": p, + }, + modelsByProxy: map[string][]string{ + "proxy": []string{"gemini-3-flash", "gpt-4o-mini"}, + }, } resp, err := al.callLLMWithModelFallback(context.Background(), nil, nil, nil) @@ -157,9 +181,15 @@ func TestCallLLMWithModelFallback_NoRetryOnNonRetryableError(t *testing.T) { } al := &AgentLoop{ - provider: p, - model: "gemini-3-flash", - modelFallbacks: []string{"gpt-4o-mini"}, + provider: p, + proxy: "proxy", + model: "gemini-3-flash", + providersByProxy: map[string]providers.LLMProvider{ + "proxy": p, + }, + modelsByProxy: map[string][]string{ + "proxy": []string{"gemini-3-flash", "gpt-4o-mini"}, + }, } _, err := al.callLLMWithModelFallback(context.Background(), nil, nil, nil) @@ -171,6 +201,55 @@ func TestCallLLMWithModelFallback_NoRetryOnNonRetryableError(t *testing.T) { } } +func TestCallLLMWithModelFallback_SwitchesProxyAfterProxyModelsExhausted(t *testing.T) { + primary := &fallbackTestProvider{ + byModel: map[string]fallbackResult{ + "gemini-3-flash": {err: fmt.Errorf(`API error (status 502): {"error":{"message":"unknown provider for model gemini-3-flash"}}`)}, + "gpt-4o-mini": {err: fmt.Errorf(`API error (status 400): {"error":{"message":"model not found"}}`)}, + }, + } + backup := &fallbackTestProvider{ + byModel: map[string]fallbackResult{ + "gemini-3-flash": {err: fmt.Errorf(`API error (status 400): {"error":{"message":"model not found"}}`)}, + "deepseek-chat": {resp: &providers.LLMResponse{Content: "ok"}}, + }, + } + + al := &AgentLoop{ + proxy: "primary", + proxyFallbacks: []string{"backup"}, + model: "gemini-3-flash", + providersByProxy: map[string]providers.LLMProvider{ + "primary": primary, + "backup": backup, + }, + modelsByProxy: map[string][]string{ + "primary": []string{"gemini-3-flash", "gpt-4o-mini"}, + "backup": []string{"deepseek-chat"}, + }, + } + + resp, err := al.callLLMWithModelFallback(context.Background(), nil, nil, nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if resp == nil || resp.Content != "ok" { + t.Fatalf("unexpected response: %+v", resp) + } + if al.proxy != "backup" { + t.Fatalf("expected proxy switch to backup, got %q", al.proxy) + } + if al.model != "deepseek-chat" { + t.Fatalf("expected model switch to deepseek-chat, got %q", al.model) + } + if len(primary.called) != 2 { + t.Fatalf("expected 2 model attempts in primary, got %d (%v)", len(primary.called), primary.called) + } + if len(backup.called) != 2 || backup.called[1] != "deepseek-chat" { + t.Fatalf("unexpected backup attempts: %v", backup.called) + } +} + func TestShouldRetryWithFallbackModel_UnknownProviderError(t *testing.T) { err := fmt.Errorf(`API error (status 502): {"error":{"message":"unknown provider for model gemini-3-flash","type":"servererror"}}`) if !shouldRetryWithFallbackModel(err) { diff --git a/pkg/agent/loop_model_switch_test.go b/pkg/agent/loop_model_switch_test.go index b03d78d..7a4b84f 100644 --- a/pkg/agent/loop_model_switch_test.go +++ b/pkg/agent/loop_model_switch_test.go @@ -2,28 +2,20 @@ package agent import "testing" -func TestApplyRuntimeModelConfig_Model(t *testing.T) { - al := &AgentLoop{model: "old-model"} - al.applyRuntimeModelConfig("agents.defaults.model", "new-model") - if al.model != "new-model" { - t.Fatalf("expected runtime model updated, got %q", al.model) +func TestApplyRuntimeModelConfig_ProxyFallbacks(t *testing.T) { + al := &AgentLoop{proxyFallbacks: []string{"old-proxy"}} + al.applyRuntimeModelConfig("agents.defaults.proxy_fallbacks", []interface{}{"backup-a", "", "backup-b"}) + if len(al.proxyFallbacks) != 2 { + t.Fatalf("expected 2 fallbacks, got %d: %v", len(al.proxyFallbacks), al.proxyFallbacks) + } + if al.proxyFallbacks[0] != "backup-a" || al.proxyFallbacks[1] != "backup-b" { + t.Fatalf("unexpected fallbacks: %v", al.proxyFallbacks) } } -func TestApplyRuntimeModelConfig_ModelFallbacks(t *testing.T) { - al := &AgentLoop{modelFallbacks: []string{"old-fallback"}} - al.applyRuntimeModelConfig("agents.defaults.model_fallbacks", []interface{}{"gpt-4o-mini", "", "claude-3-5-sonnet"}) - if len(al.modelFallbacks) != 2 { - t.Fatalf("expected 2 fallbacks, got %d: %v", len(al.modelFallbacks), al.modelFallbacks) - } - if al.modelFallbacks[0] != "gpt-4o-mini" || al.modelFallbacks[1] != "claude-3-5-sonnet" { - t.Fatalf("unexpected fallbacks: %v", al.modelFallbacks) - } -} - -func TestParseModelFallbacks_StringValue(t *testing.T) { - fallbacks := parseModelFallbacks("gpt-4o-mini") - if len(fallbacks) != 1 || fallbacks[0] != "gpt-4o-mini" { - t.Fatalf("unexpected parse result: %v", fallbacks) +func TestParseStringList_StringValue(t *testing.T) { + out := parseStringList("backup-a") + if len(out) != 1 || out[0] != "backup-a" { + t.Fatalf("unexpected parse result: %v", out) } } diff --git a/pkg/config/config.go b/pkg/config/config.go index 247eb9b..032ea8a 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -29,8 +29,8 @@ type AgentsConfig struct { type AgentDefaults struct { Workspace string `json:"workspace" env:"CLAWGO_AGENTS_DEFAULTS_WORKSPACE"` - Model string `json:"model" env:"CLAWGO_AGENTS_DEFAULTS_MODEL"` - ModelFallbacks []string `json:"model_fallbacks" env:"CLAWGO_AGENTS_DEFAULTS_MODEL_FALLBACKS"` + Proxy string `json:"proxy" env:"CLAWGO_AGENTS_DEFAULTS_PROXY"` + ProxyFallbacks []string `json:"proxy_fallbacks" env:"CLAWGO_AGENTS_DEFAULTS_PROXY_FALLBACKS"` MaxTokens int `json:"max_tokens" env:"CLAWGO_AGENTS_DEFAULTS_MAX_TOKENS"` Temperature float64 `json:"temperature" env:"CLAWGO_AGENTS_DEFAULTS_TEMPERATURE"` MaxToolIterations int `json:"max_tool_iterations" env:"CLAWGO_AGENTS_DEFAULTS_MAX_TOOL_ITERATIONS"` @@ -104,14 +104,17 @@ type DingTalkConfig struct { } type ProvidersConfig struct { - Proxy ProviderConfig `json:"proxy"` + Proxy ProviderConfig `json:"proxy"` + Proxies map[string]ProviderConfig `json:"proxies"` } type ProviderConfig struct { - APIKey string `json:"api_key" env:"CLAWGO_PROVIDERS_{{.Name}}_API_KEY"` - APIBase string `json:"api_base" env:"CLAWGO_PROVIDERS_{{.Name}}_API_BASE"` - Auth string `json:"auth" env:"CLAWGO_PROVIDERS_{{.Name}}_AUTH"` - TimeoutSec int `json:"timeout_sec" env:"CLAWGO_PROVIDERS_PROXY_TIMEOUT_SEC"` + APIKey string `json:"api_key" env:"CLAWGO_PROVIDERS_{{.Name}}_API_KEY"` + APIBase string `json:"api_base" env:"CLAWGO_PROVIDERS_{{.Name}}_API_BASE"` + Protocol string `json:"protocol" env:"CLAWGO_PROVIDERS_{{.Name}}_PROTOCOL"` + Models []string `json:"models" env:"CLAWGO_PROVIDERS_{{.Name}}_MODELS"` + Auth string `json:"auth" env:"CLAWGO_PROVIDERS_{{.Name}}_AUTH"` + TimeoutSec int `json:"timeout_sec" env:"CLAWGO_PROVIDERS_PROXY_TIMEOUT_SEC"` } type GatewayConfig struct { @@ -228,8 +231,8 @@ func DefaultConfig() *Config { Agents: AgentsConfig{ Defaults: AgentDefaults{ Workspace: filepath.Join(configDir, "workspace"), - Model: "glm-4.7", - ModelFallbacks: []string{}, + Proxy: "proxy", + ProxyFallbacks: []string{}, MaxTokens: 8192, Temperature: 0.7, MaxToolIterations: 20, @@ -288,8 +291,11 @@ func DefaultConfig() *Config { Providers: ProvidersConfig{ Proxy: ProviderConfig{ APIBase: "http://localhost:8080/v1", + Protocol: "chat_completions", + Models: []string{"glm-4.7"}, TimeoutSec: 90, }, + Proxies: map[string]ProviderConfig{}, }, Gateway: GatewayConfig{ Host: "0.0.0.0", diff --git a/pkg/config/validate.go b/pkg/config/validate.go index f466527..5324522 100644 --- a/pkg/config/validate.go +++ b/pkg/config/validate.go @@ -11,9 +11,6 @@ func Validate(cfg *Config) []error { var errs []error - if cfg.Agents.Defaults.Model == "" { - errs = append(errs, fmt.Errorf("agents.defaults.model is required")) - } if cfg.Agents.Defaults.MaxToolIterations <= 0 { errs = append(errs, fmt.Errorf("agents.defaults.max_tool_iterations must be > 0")) } @@ -36,11 +33,22 @@ func Validate(cfg *Config) []error { } } - if cfg.Providers.Proxy.APIBase == "" { - errs = append(errs, fmt.Errorf("providers.proxy.api_base is required")) + if len(cfg.Providers.Proxies) == 0 { + errs = append(errs, validateProviderConfig("providers.proxy", cfg.Providers.Proxy)...) + } else { + for name, p := range cfg.Providers.Proxies { + errs = append(errs, validateProviderConfig("providers.proxies."+name, p)...) + } } - if cfg.Providers.Proxy.TimeoutSec <= 0 { - errs = append(errs, fmt.Errorf("providers.proxy.timeout_sec must be > 0")) + if cfg.Agents.Defaults.Proxy != "" { + if !providerExists(cfg, cfg.Agents.Defaults.Proxy) { + errs = append(errs, fmt.Errorf("agents.defaults.proxy %q not found in providers", cfg.Agents.Defaults.Proxy)) + } + } + for _, name := range cfg.Agents.Defaults.ProxyFallbacks { + if !providerExists(cfg, name) { + errs = append(errs, fmt.Errorf("agents.defaults.proxy_fallbacks contains unknown proxy %q", name)) + } } if cfg.Gateway.Port <= 0 || cfg.Gateway.Port > 65535 { @@ -129,3 +137,35 @@ func Validate(cfg *Config) []error { return errs } + +func validateProviderConfig(path string, p ProviderConfig) []error { + var errs []error + if p.APIBase == "" { + errs = append(errs, fmt.Errorf("%s.api_base is required", path)) + } + if p.Protocol != "" { + switch p.Protocol { + case "chat_completions", "responses": + default: + errs = append(errs, fmt.Errorf("%s.protocol must be one of: chat_completions, responses", path)) + } + } + if p.TimeoutSec <= 0 { + errs = append(errs, fmt.Errorf("%s.timeout_sec must be > 0", path)) + } + if len(p.Models) == 0 { + errs = append(errs, fmt.Errorf("%s.models must contain at least one model", path)) + } + return errs +} + +func providerExists(cfg *Config, name string) bool { + if name == "proxy" && cfg.Providers.Proxy.APIBase != "" { + return true + } + if cfg.Providers.Proxies == nil { + return false + } + _, ok := cfg.Providers.Proxies[name] + return ok +} diff --git a/pkg/providers/openai_provider.go b/pkg/providers/openai_provider.go index dda4ace..2d683d4 100644 --- a/pkg/providers/openai_provider.go +++ b/pkg/providers/openai_provider.go @@ -21,21 +21,31 @@ import ( "github.com/openai/openai-go/v3" "github.com/openai/openai-go/v3/option" "github.com/openai/openai-go/v3/packages/param" + "github.com/openai/openai-go/v3/responses" "github.com/openai/openai-go/v3/shared" "github.com/openai/openai-go/v3/shared/constant" ) +const ( + ProtocolChatCompletions = "chat_completions" + ProtocolResponses = "responses" +) + type HTTPProvider struct { - apiKey string - apiBase string - authMode string - timeout time.Duration - httpClient *http.Client - client openai.Client + apiKey string + apiBase string + protocol string + defaultModel string + authMode string + timeout time.Duration + httpClient *http.Client + client openai.Client } -func NewHTTPProvider(apiKey, apiBase, authMode string, timeout time.Duration) *HTTPProvider { +func NewHTTPProvider(apiKey, apiBase, protocol, defaultModel, authMode string, timeout time.Duration) *HTTPProvider { normalizedBase := normalizeAPIBase(apiBase) + resolvedProtocol := normalizeProtocol(protocol) + resolvedDefaultModel := strings.TrimSpace(defaultModel) httpClient := &http.Client{Timeout: timeout} clientOpts := []option.RequestOption{ option.WithBaseURL(normalizedBase), @@ -54,12 +64,14 @@ func NewHTTPProvider(apiKey, apiBase, authMode string, timeout time.Duration) *H } return &HTTPProvider{ - apiKey: apiKey, - apiBase: normalizedBase, - authMode: authMode, - timeout: timeout, - httpClient: httpClient, - client: openai.NewClient(clientOpts...), + apiKey: apiKey, + apiBase: normalizedBase, + protocol: resolvedProtocol, + defaultModel: resolvedDefaultModel, + authMode: authMode, + timeout: timeout, + httpClient: httpClient, + client: openai.NewClient(clientOpts...), } } @@ -70,17 +82,29 @@ func (p *HTTPProvider) Chat(ctx context.Context, messages []Message, tools []Too logger.DebugCF("provider", "OpenAI SDK chat request", map[string]interface{}{ "api_base": p.apiBase, + "protocol": p.protocol, "model": model, "messages_count": len(messages), "tools_count": len(tools), "timeout": p.timeout.String(), }) + if p.protocol == ProtocolResponses { + params, err := buildResponsesParams(messages, tools, model, options) + if err != nil { + return nil, err + } + resp, err := p.client.Responses.New(ctx, params) + if err != nil { + return nil, fmt.Errorf("API error: %w", err) + } + return mapResponsesAPIResponse(resp), nil + } + params, err := buildChatParams(messages, tools, model, options) if err != nil { return nil, err } - resp, err := p.client.Chat.Completions.New(ctx, params) if err != nil { return nil, fmt.Errorf("API error: %w", err) @@ -88,6 +112,84 @@ func (p *HTTPProvider) Chat(ctx context.Context, messages []Message, tools []Too return mapChatCompletionResponse(resp), nil } +func buildResponsesParams(messages []Message, tools []ToolDefinition, model string, opts map[string]interface{}) (responses.ResponseNewParams, error) { + params := responses.ResponseNewParams{ + Model: model, + Input: responses.ResponseNewParamsInputUnion{ + OfInputItemList: make(responses.ResponseInputParam, 0, len(messages)), + }, + } + + for _, msg := range messages { + inputItems := toResponsesInputItems(msg) + params.Input.OfInputItemList = append(params.Input.OfInputItemList, inputItems...) + } + + if len(tools) > 0 { + params.Tools = make([]responses.ToolUnionParam, 0, len(tools)) + for _, t := range tools { + tool := responses.ToolParamOfFunction(t.Function.Name, t.Function.Parameters, false) + if t.Function.Description != "" && tool.OfFunction != nil { + tool.OfFunction.Description = param.NewOpt(t.Function.Description) + } + params.Tools = append(params.Tools, tool) + } + params.ToolChoice.OfToolChoiceMode = param.NewOpt(responses.ToolChoiceOptionsAuto) + } + + if maxTokens, ok := int64FromOption(opts, "max_tokens"); ok { + params.MaxOutputTokens = param.NewOpt(maxTokens) + } + if temperature, ok := float64FromOption(opts, "temperature"); ok { + params.Temperature = param.NewOpt(temperature) + } + + return params, nil +} + +func toResponsesInputItems(msg Message) []responses.ResponseInputItemUnionParam { + role := strings.ToLower(strings.TrimSpace(msg.Role)) + switch role { + case "system": + return []responses.ResponseInputItemUnionParam{ + responses.ResponseInputItemParamOfMessage(msg.Content, responses.EasyInputMessageRoleSystem), + } + case "developer": + return []responses.ResponseInputItemUnionParam{ + responses.ResponseInputItemParamOfMessage(msg.Content, responses.EasyInputMessageRoleDeveloper), + } + case "assistant": + items := []responses.ResponseInputItemUnionParam{ + responses.ResponseInputItemParamOfMessage(msg.Content, responses.EasyInputMessageRoleAssistant), + } + for _, tc := range msg.ToolCalls { + name, arguments := normalizeOutboundToolCall(tc) + if name == "" { + continue + } + callID := strings.TrimSpace(tc.ID) + if callID == "" { + callID = fmt.Sprintf("call_%d", len(items)) + } + items = append(items, responses.ResponseInputItemParamOfFunctionCall(arguments, callID, name)) + } + return items + case "tool": + if strings.TrimSpace(msg.ToolCallID) == "" { + return []responses.ResponseInputItemUnionParam{ + responses.ResponseInputItemParamOfMessage(msg.Content, responses.EasyInputMessageRoleUser), + } + } + return []responses.ResponseInputItemUnionParam{ + responses.ResponseInputItemParamOfFunctionCallOutput(msg.ToolCallID, msg.Content), + } + default: + return []responses.ResponseInputItemUnionParam{ + responses.ResponseInputItemParamOfMessage(msg.Content, responses.EasyInputMessageRoleUser), + } + } +} + func buildChatParams(messages []Message, tools []ToolDefinition, model string, opts map[string]interface{}) (openai.ChatCompletionNewParams, error) { params := openai.ChatCompletionNewParams{ Model: model, @@ -264,6 +366,74 @@ func mapChatCompletionResponse(resp *openai.ChatCompletion) *LLMResponse { } } +func mapResponsesAPIResponse(resp *responses.Response) *LLMResponse { + if resp == nil { + return &LLMResponse{ + Content: "", + FinishReason: "stop", + } + } + + content := resp.OutputText() + toolCalls := make([]ToolCall, 0) + for _, item := range resp.Output { + if item.Type != "function_call" { + continue + } + call := item.AsFunctionCall() + if strings.TrimSpace(call.Name) == "" { + continue + } + args := map[string]interface{}{} + if call.Arguments != "" { + if err := json.Unmarshal([]byte(call.Arguments), &args); err != nil { + args["raw"] = call.Arguments + } + } + id := strings.TrimSpace(call.CallID) + if id == "" { + id = strings.TrimSpace(call.ID) + } + if id == "" { + id = fmt.Sprintf("call_%d", len(toolCalls)+1) + } + toolCalls = append(toolCalls, ToolCall{ + ID: id, + Name: call.Name, + Arguments: args, + }) + } + + if len(toolCalls) == 0 { + compatCalls, cleanedContent := parseCompatFunctionCalls(content) + if len(compatCalls) > 0 { + toolCalls = compatCalls + content = cleanedContent + } + } + + finishReason := strings.TrimSpace(string(resp.Status)) + if finishReason == "" || finishReason == "completed" { + finishReason = "stop" + } + + var usage *UsageInfo + if resp.Usage.TotalTokens > 0 || resp.Usage.InputTokens > 0 || resp.Usage.OutputTokens > 0 { + usage = &UsageInfo{ + PromptTokens: int(resp.Usage.InputTokens), + CompletionTokens: int(resp.Usage.OutputTokens), + TotalTokens: int(resp.Usage.TotalTokens), + } + } + + return &LLMResponse{ + Content: content, + ToolCalls: toolCalls, + FinishReason: finishReason, + Usage: usage, + } +} + func int64FromOption(options map[string]interface{}, key string) (int64, bool) { if options == nil { return 0, false @@ -315,25 +485,21 @@ func normalizeAPIBase(raw string) string { return strings.TrimRight(trimmed, "/") } - path := strings.TrimRight(u.Path, "/") - for _, suffix := range []string{ - "/chat/completions", - "/chat", - "/responses", - } { - if strings.HasSuffix(path, suffix) { - path = strings.TrimSuffix(path, suffix) - break - } - } - - if path == "" { - path = "/" - } - u.Path = path + u.Path = strings.TrimRight(u.Path, "/") return strings.TrimRight(u.String(), "/") } +func normalizeProtocol(raw string) string { + switch strings.TrimSpace(raw) { + case "", ProtocolChatCompletions: + return ProtocolChatCompletions + case ProtocolResponses: + return ProtocolResponses + default: + return ProtocolChatCompletions + } +} + func parseCompatFunctionCalls(content string) ([]ToolCall, string) { if strings.TrimSpace(content) == "" || !strings.Contains(content, "") { return nil, content @@ -406,20 +572,124 @@ func extractTag(src string, tag string) string { } func (p *HTTPProvider) GetDefaultModel() string { - return "" + return p.defaultModel } func CreateProvider(cfg *config.Config) (LLMProvider, error) { - apiKey := cfg.Providers.Proxy.APIKey - apiBase := cfg.Providers.Proxy.APIBase - authMode := cfg.Providers.Proxy.Auth - - if apiBase == "" { - return nil, fmt.Errorf("no API base (CLIProxyAPI) configured") + name := strings.TrimSpace(cfg.Agents.Defaults.Proxy) + if name == "" { + name = "proxy" } - if cfg.Providers.Proxy.TimeoutSec <= 0 { - return nil, fmt.Errorf("invalid providers.proxy.timeout_sec: %d", cfg.Providers.Proxy.TimeoutSec) - } - - return NewHTTPProvider(apiKey, apiBase, authMode, time.Duration(cfg.Providers.Proxy.TimeoutSec)*time.Second), nil + return CreateProviderByName(cfg, name) +} + +func CreateProviderByName(cfg *config.Config, name string) (LLMProvider, error) { + pc, err := getProviderConfigByName(cfg, name) + if err != nil { + return nil, err + } + if pc.APIBase == "" { + return nil, fmt.Errorf("no API base configured for provider %q", name) + } + if pc.TimeoutSec <= 0 { + return nil, fmt.Errorf("invalid timeout_sec for provider %q: %d", name, pc.TimeoutSec) + } + defaultModel := "" + if len(pc.Models) > 0 { + defaultModel = pc.Models[0] + } + return NewHTTPProvider(pc.APIKey, pc.APIBase, pc.Protocol, defaultModel, pc.Auth, time.Duration(pc.TimeoutSec)*time.Second), nil +} + +func CreateProviders(cfg *config.Config) (map[string]LLMProvider, error) { + configs := getAllProviderConfigs(cfg) + if len(configs) == 0 { + return nil, fmt.Errorf("no providers configured") + } + out := make(map[string]LLMProvider, len(configs)) + for name := range configs { + p, err := CreateProviderByName(cfg, name) + if err != nil { + return nil, err + } + out[name] = p + } + return out, nil +} + +func GetProviderModels(cfg *config.Config, name string) []string { + pc, err := getProviderConfigByName(cfg, name) + if err != nil { + return nil + } + out := make([]string, 0, len(pc.Models)) + seen := map[string]bool{} + for _, m := range pc.Models { + model := strings.TrimSpace(m) + if model == "" || seen[model] { + continue + } + seen[model] = true + out = append(out, model) + } + return out +} + +func ListProviderNames(cfg *config.Config) []string { + configs := getAllProviderConfigs(cfg) + if len(configs) == 0 { + return nil + } + names := make([]string, 0, len(configs)) + for name := range configs { + names = append(names, name) + } + return names +} + +func getAllProviderConfigs(cfg *config.Config) map[string]config.ProviderConfig { + out := map[string]config.ProviderConfig{} + if cfg == nil { + return out + } + includeLegacyProxy := len(cfg.Providers.Proxies) == 0 || strings.TrimSpace(cfg.Agents.Defaults.Proxy) == "proxy" || containsStringTrimmed(cfg.Agents.Defaults.ProxyFallbacks, "proxy") + if includeLegacyProxy && (cfg.Providers.Proxy.APIBase != "" || cfg.Providers.Proxy.APIKey != "" || cfg.Providers.Proxy.TimeoutSec > 0) { + out["proxy"] = cfg.Providers.Proxy + } + for name, pc := range cfg.Providers.Proxies { + trimmed := strings.TrimSpace(name) + if trimmed == "" { + continue + } + out[trimmed] = pc + } + return out +} + +func containsStringTrimmed(values []string, target string) bool { + t := strings.TrimSpace(target) + for _, v := range values { + if strings.TrimSpace(v) == t { + return true + } + } + return false +} + +func getProviderConfigByName(cfg *config.Config, name string) (config.ProviderConfig, error) { + if cfg == nil { + return config.ProviderConfig{}, fmt.Errorf("nil config") + } + trimmed := strings.TrimSpace(name) + if trimmed == "" { + return config.ProviderConfig{}, fmt.Errorf("empty provider name") + } + if trimmed == "proxy" { + return cfg.Providers.Proxy, nil + } + pc, ok := cfg.Providers.Proxies[trimmed] + if !ok { + return config.ProviderConfig{}, fmt.Errorf("provider %q not found", trimmed) + } + return pc, nil } diff --git a/pkg/providers/http_provider_test.go b/pkg/providers/provider_test.go similarity index 83% rename from pkg/providers/http_provider_test.go rename to pkg/providers/provider_test.go index eb225b2..902b00d 100644 --- a/pkg/providers/http_provider_test.go +++ b/pkg/providers/provider_test.go @@ -53,10 +53,10 @@ func TestNormalizeAPIBase_CompatibilityPaths(t *testing.T) { in string want string }{ - {"http://localhost:8080/v1/chat/completions", "http://localhost:8080/v1"}, - {"http://localhost:8080/v1/chat", "http://localhost:8080/v1"}, - {"http://localhost:8080/v1/responses", "http://localhost:8080/v1"}, + {"http://localhost:8080/v1/chat/completions", "http://localhost:8080/v1/chat/completions"}, + {"http://localhost:8080/v1/responses", "http://localhost:8080/v1/responses"}, {"http://localhost:8080/v1", "http://localhost:8080/v1"}, + {"http://localhost:8080/v1/", "http://localhost:8080/v1"}, } for _, tt := range tests { @@ -67,6 +67,25 @@ func TestNormalizeAPIBase_CompatibilityPaths(t *testing.T) { } } +func TestNormalizeProtocol(t *testing.T) { + tests := []struct { + in string + want string + }{ + {"", ProtocolChatCompletions}, + {"chat_completions", ProtocolChatCompletions}, + {"responses", ProtocolResponses}, + {"invalid", ProtocolChatCompletions}, + } + + for _, tt := range tests { + got := normalizeProtocol(tt.in) + if got != tt.want { + t.Fatalf("normalizeProtocol(%q) = %q, want %q", tt.in, got, tt.want) + } + } +} + func TestParseCompatFunctionCalls_NoMarkup(t *testing.T) { calls, cleaned := parseCompatFunctionCalls("hello") if len(calls) != 0 {