release: v0.2.0

This commit is contained in:
lpf
2026-03-11 19:00:19 +08:00
parent 1c0e463d07
commit 13108b0333
104 changed files with 6519 additions and 4296 deletions

View File

@@ -342,22 +342,20 @@ func NewAgentLoop(cfg *config.Config, msgBus *bus.MessageBus, provider providers
subagentDigestDelay: 5 * time.Second,
subagentDigests: map[string]*subagentDigestState{},
}
if _, primaryModel := config.ParseProviderModelRef(cfg.Agents.Defaults.Model.Primary); strings.TrimSpace(primaryModel) != "" {
loop.model = strings.TrimSpace(primaryModel)
}
go loop.runSubagentDigestTicker()
// Initialize provider fallback chain (primary + proxy_fallbacks).
// Initialize provider fallback chain (primary + model fallbacks).
loop.providerPool = map[string]providers.LLMProvider{}
loop.providerNames = []string{}
primaryName := cfg.Agents.Defaults.Proxy
if primaryName == "" {
primaryName = "proxy"
}
primaryName := config.PrimaryProviderName(cfg)
loop.providerPool[primaryName] = provider
loop.providerNames = append(loop.providerNames, primaryName)
if strings.TrimSpace(primaryName) == "proxy" {
loop.providerResponses[primaryName] = cfg.Providers.Proxy.Responses
} else if pc, ok := cfg.Providers.Proxies[primaryName]; ok {
if pc, ok := config.ProviderConfigByName(cfg, primaryName); ok {
loop.providerResponses[primaryName] = pc.Responses
}
for _, name := range cfg.Agents.Defaults.ProxyFallbacks {
for _, name := range cfg.Agents.Defaults.Model.Fallbacks {
if name == "" {
continue
}
@@ -371,13 +369,13 @@ func NewAgentLoop(cfg *config.Config, msgBus *bus.MessageBus, provider providers
if dup {
continue
}
if p2, err := providers.CreateProviderByName(cfg, name); err == nil {
loop.providerPool[name] = p2
loop.providerNames = append(loop.providerNames, name)
if pc, ok := cfg.Providers.Proxies[name]; ok {
loop.providerResponses[name] = pc.Responses
if p2, err := providers.CreateProviderByName(cfg, name); err == nil {
loop.providerPool[name] = p2
loop.providerNames = append(loop.providerNames, name)
if pc, ok := config.ProviderConfigByName(cfg, name); ok {
loop.providerResponses[name] = pc.Responses
}
}
}
}
// Inject recursive run logic so subagents can use full tool-calling flows.

View File

@@ -914,8 +914,8 @@ func collectRiskyConfigPaths(oldMap, newMap map[string]interface{}) []string {
"channels.telegram.token",
"channels.telegram.allow_from",
"channels.telegram.allow_chats",
"providers.proxy.api_base",
"providers.proxy.api_key",
"models.providers.openai.api_base",
"models.providers.openai.api_key",
"gateway.token",
"gateway.port",
}
@@ -923,9 +923,9 @@ func collectRiskyConfigPaths(oldMap, newMap map[string]interface{}) []string {
for _, path := range paths {
seen[path] = true
}
for _, name := range collectProviderProxyNames(oldMap, newMap) {
for _, name := range collectProviderNames(oldMap, newMap) {
for _, field := range []string{"api_base", "api_key"} {
path := "providers.proxies." + name + "." + field
path := "models.providers." + name + "." + field
if !seen[path] {
paths = append(paths, path)
seen[path] = true
@@ -935,13 +935,13 @@ func collectRiskyConfigPaths(oldMap, newMap map[string]interface{}) []string {
return paths
}
func collectProviderProxyNames(maps ...map[string]interface{}) []string {
func collectProviderNames(maps ...map[string]interface{}) []string {
seen := map[string]bool{}
names := make([]string, 0)
for _, root := range maps {
providers, _ := root["providers"].(map[string]interface{})
proxies, _ := providers["proxies"].(map[string]interface{})
for name := range proxies {
models, _ := root["models"].(map[string]interface{})
providers, _ := models["providers"].(map[string]interface{})
for name := range providers {
if strings.TrimSpace(name) == "" || seen[name] {
continue
}
@@ -1001,6 +1001,7 @@ func (s *Server) handleWebUIProviderOAuthStart(w http.ResponseWriter, r *http.Re
var body struct {
Provider string `json:"provider"`
AccountLabel string `json:"account_label"`
NetworkProxy string `json:"network_proxy"`
ProviderConfig cfgpkg.ProviderConfig `json:"provider_config"`
}
if r.Method == http.MethodPost {
@@ -1011,6 +1012,7 @@ func (s *Server) handleWebUIProviderOAuthStart(w http.ResponseWriter, r *http.Re
} else {
body.Provider = strings.TrimSpace(r.URL.Query().Get("provider"))
body.AccountLabel = strings.TrimSpace(r.URL.Query().Get("account_label"))
body.NetworkProxy = strings.TrimSpace(r.URL.Query().Get("network_proxy"))
}
cfg, pc, err := s.resolveProviderConfig(strings.TrimSpace(body.Provider), body.ProviderConfig)
if err != nil {
@@ -1027,7 +1029,10 @@ func (s *Server) handleWebUIProviderOAuthStart(w http.ResponseWriter, r *http.Re
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
flow, err := loginMgr.StartManualFlow()
flow, err := loginMgr.StartManualFlowWithOptions(providers.OAuthLoginOptions{
AccountLabel: body.AccountLabel,
NetworkProxy: firstNonEmptyString(strings.TrimSpace(body.NetworkProxy), strings.TrimSpace(pc.OAuth.NetworkProxy)),
})
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
@@ -1044,6 +1049,7 @@ func (s *Server) handleWebUIProviderOAuthStart(w http.ResponseWriter, r *http.Re
"user_code": flow.UserCode,
"instructions": flow.Instructions,
"account_label": strings.TrimSpace(body.AccountLabel),
"network_proxy": strings.TrimSpace(body.NetworkProxy),
})
}
@@ -1061,6 +1067,7 @@ func (s *Server) handleWebUIProviderOAuthComplete(w http.ResponseWriter, r *http
FlowID string `json:"flow_id"`
CallbackURL string `json:"callback_url"`
AccountLabel string `json:"account_label"`
NetworkProxy string `json:"network_proxy"`
ProviderConfig cfgpkg.ProviderConfig `json:"provider_config"`
}
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
@@ -1091,6 +1098,7 @@ func (s *Server) handleWebUIProviderOAuthComplete(w http.ResponseWriter, r *http
}
session, models, err := loginMgr.CompleteManualFlowWithOptions(r.Context(), pc.APIBase, flow, body.CallbackURL, providers.OAuthLoginOptions{
AccountLabel: strings.TrimSpace(body.AccountLabel),
NetworkProxy: firstNonEmptyString(strings.TrimSpace(body.NetworkProxy), strings.TrimSpace(pc.OAuth.NetworkProxy)),
})
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
@@ -1111,6 +1119,7 @@ func (s *Server) handleWebUIProviderOAuthComplete(w http.ResponseWriter, r *http
"ok": true,
"account": session.Email,
"credential_file": session.CredentialFile,
"network_proxy": session.NetworkProxy,
"models": models,
})
}
@@ -1130,6 +1139,7 @@ func (s *Server) handleWebUIProviderOAuthImport(w http.ResponseWriter, r *http.R
}
providerName := strings.TrimSpace(r.FormValue("provider"))
accountLabel := strings.TrimSpace(r.FormValue("account_label"))
networkProxy := strings.TrimSpace(r.FormValue("network_proxy"))
inlineCfgRaw := strings.TrimSpace(r.FormValue("provider_config"))
var inlineCfg cfgpkg.ProviderConfig
if inlineCfgRaw != "" {
@@ -1165,6 +1175,7 @@ func (s *Server) handleWebUIProviderOAuthImport(w http.ResponseWriter, r *http.R
}
session, models, err := loginMgr.ImportAuthJSONWithOptions(r.Context(), pc.APIBase, header.Filename, data, providers.OAuthLoginOptions{
AccountLabel: accountLabel,
NetworkProxy: firstNonEmptyString(networkProxy, strings.TrimSpace(pc.OAuth.NetworkProxy)),
})
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
@@ -1185,6 +1196,7 @@ func (s *Server) handleWebUIProviderOAuthImport(w http.ResponseWriter, r *http.R
"ok": true,
"account": session.Email,
"credential_file": session.CredentialFile,
"network_proxy": session.NetworkProxy,
"models": models,
})
}
@@ -1401,10 +1413,7 @@ func (s *Server) loadProviderConfig(name string) (*cfgpkg.Config, cfgpkg.Provide
return nil, cfgpkg.ProviderConfig{}, err
}
providerName := strings.TrimSpace(name)
if providerName == "" || providerName == "proxy" {
return cfg, cfg.Providers.Proxy, nil
}
pc, ok := cfg.Providers.Proxies[providerName]
pc, ok := cfg.Models.Providers[providerName]
if !ok {
return nil, cfgpkg.ProviderConfig{}, fmt.Errorf("provider %q not found", providerName)
}
@@ -1435,14 +1444,10 @@ func (s *Server) saveProviderConfig(cfg *cfgpkg.Config, name string, pc cfgpkg.P
return fmt.Errorf("config is nil")
}
providerName := strings.TrimSpace(name)
if providerName == "" || providerName == "proxy" {
cfg.Providers.Proxy = pc
} else {
if cfg.Providers.Proxies == nil {
cfg.Providers.Proxies = map[string]cfgpkg.ProviderConfig{}
}
cfg.Providers.Proxies[providerName] = pc
if cfg.Models.Providers == nil {
cfg.Models.Providers = map[string]cfgpkg.ProviderConfig{}
}
cfg.Models.Providers[providerName] = pc
if err := cfgpkg.SaveConfig(s.configPath, cfg); err != nil {
return err
}
@@ -6052,7 +6057,7 @@ func hotReloadFieldInfo() []map[string]interface{} {
{"path": "logging.*", "name": "Logging", "description": "Log level, persistence, and related settings"},
{"path": "sentinel.*", "name": "Sentinel", "description": "Health checks and auto-heal behavior"},
{"path": "agents.*", "name": "Agent", "description": "Models, policies, and default behavior"},
{"path": "providers.*", "name": "Providers", "description": "LLM providers and proxy settings"},
{"path": "models.providers.*", "name": "Providers", "description": "LLM provider registry and auth settings"},
{"path": "tools.*", "name": "Tools", "description": "Tool toggles and runtime options"},
{"path": "channels.*", "name": "Channels", "description": "Telegram and other channel settings"},
{"path": "cron.*", "name": "Cron", "description": "Global cron runtime settings"},

View File

@@ -248,16 +248,20 @@ func TestHandleWebUIConfigRequiresConfirmForProviderAPIBaseChange(t *testing.T)
cfg := cfgpkg.DefaultConfig()
cfg.Logging.Enabled = false
cfg.Providers.Proxy.APIBase = "https://old.example/v1"
cfg.Providers.Proxy.APIKey = "test-key"
pc := cfg.Models.Providers["openai"]
pc.APIBase = "https://old.example/v1"
pc.APIKey = "test-key"
cfg.Models.Providers["openai"] = pc
if err := cfgpkg.SaveConfig(cfgPath, cfg); err != nil {
t.Fatalf("save config: %v", err)
}
bodyCfg := cfgpkg.DefaultConfig()
bodyCfg.Logging.Enabled = false
bodyCfg.Providers.Proxy.APIBase = "https://new.example/v1"
bodyCfg.Providers.Proxy.APIKey = "test-key"
bodyPC := bodyCfg.Models.Providers["openai"]
bodyPC.APIBase = "https://new.example/v1"
bodyPC.APIKey = "test-key"
bodyCfg.Models.Providers["openai"] = bodyPC
body, err := json.Marshal(bodyCfg)
if err != nil {
t.Fatalf("marshal body: %v", err)
@@ -278,8 +282,8 @@ func TestHandleWebUIConfigRequiresConfirmForProviderAPIBaseChange(t *testing.T)
if !strings.Contains(rec.Body.String(), `"requires_confirm":true`) {
t.Fatalf("expected requires_confirm response, got: %s", rec.Body.String())
}
if !strings.Contains(rec.Body.String(), `providers.proxy.api_base`) {
t.Fatalf("expected providers.proxy.api_base in changed_fields, got: %s", rec.Body.String())
if !strings.Contains(rec.Body.String(), `models.providers.openai.api_base`) {
t.Fatalf("expected models.providers.openai.api_base in changed_fields, got: %s", rec.Body.String())
}
}
@@ -291,7 +295,7 @@ func TestHandleWebUIConfigRequiresConfirmForCustomProviderSecretChange(t *testin
cfg := cfgpkg.DefaultConfig()
cfg.Logging.Enabled = false
cfg.Providers.Proxies["backup"] = cfgpkg.ProviderConfig{
cfg.Models.Providers["backup"] = cfgpkg.ProviderConfig{
APIBase: "https://backup.example/v1",
APIKey: "old-secret",
Models: []string{"backup-model"},
@@ -304,7 +308,7 @@ func TestHandleWebUIConfigRequiresConfirmForCustomProviderSecretChange(t *testin
bodyCfg := cfgpkg.DefaultConfig()
bodyCfg.Logging.Enabled = false
bodyCfg.Providers.Proxies["backup"] = cfgpkg.ProviderConfig{
bodyCfg.Models.Providers["backup"] = cfgpkg.ProviderConfig{
APIBase: "https://backup.example/v1",
APIKey: "new-secret",
Models: []string{"backup-model"},
@@ -331,8 +335,8 @@ func TestHandleWebUIConfigRequiresConfirmForCustomProviderSecretChange(t *testin
if !strings.Contains(rec.Body.String(), `"requires_confirm":true`) {
t.Fatalf("expected requires_confirm response, got: %s", rec.Body.String())
}
if !strings.Contains(rec.Body.String(), `providers.proxies.backup.api_key`) {
t.Fatalf("expected providers.proxies.backup.api_key in changed_fields, got: %s", rec.Body.String())
if !strings.Contains(rec.Body.String(), `models.providers.backup.api_key`) {
t.Fatalf("expected models.providers.backup.api_key in changed_fields, got: %s", rec.Body.String())
}
}

View File

@@ -9,6 +9,7 @@ import (
"io"
"os"
"path/filepath"
"strings"
"sync"
"time"
@@ -16,16 +17,16 @@ import (
)
type Config struct {
Agents AgentsConfig `json:"agents"`
Channels ChannelsConfig `json:"channels"`
Providers ProvidersConfig `json:"providers"`
Gateway GatewayConfig `json:"gateway"`
Cron CronConfig `json:"cron"`
Tools ToolsConfig `json:"tools"`
Logging LoggingConfig `json:"logging"`
Sentinel SentinelConfig `json:"sentinel"`
Memory MemoryConfig `json:"memory"`
mu sync.RWMutex
Agents AgentsConfig `json:"agents"`
Channels ChannelsConfig `json:"channels"`
Models ModelsConfig `json:"models,omitempty"`
Gateway GatewayConfig `json:"gateway"`
Cron CronConfig `json:"cron"`
Tools ToolsConfig `json:"tools"`
Logging LoggingConfig `json:"logging"`
Sentinel SentinelConfig `json:"sentinel"`
Memory MemoryConfig `json:"memory"`
mu sync.RWMutex
}
type AgentsConfig struct {
@@ -107,7 +108,7 @@ type SubagentToolsConfig struct {
}
type SubagentRuntimeConfig struct {
Proxy string `json:"proxy,omitempty"`
Provider string `json:"provider,omitempty"`
Model string `json:"model,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TimeoutSec int `json:"timeout_sec,omitempty"`
@@ -120,8 +121,7 @@ type SubagentRuntimeConfig struct {
type AgentDefaults struct {
Workspace string `json:"workspace" env:"CLAWGO_AGENTS_DEFAULTS_WORKSPACE"`
Proxy string `json:"proxy" env:"CLAWGO_AGENTS_DEFAULTS_PROXY"`
ProxyFallbacks []string `json:"proxy_fallbacks" env:"CLAWGO_AGENTS_DEFAULTS_PROXY_FALLBACKS"`
Model AgentModelDefaults `json:"model,omitempty"`
MaxTokens int `json:"max_tokens" env:"CLAWGO_AGENTS_DEFAULTS_MAX_TOKENS"`
Temperature float64 `json:"temperature" env:"CLAWGO_AGENTS_DEFAULTS_TEMPERATURE"`
MaxToolIterations int `json:"max_tool_iterations" env:"CLAWGO_AGENTS_DEFAULTS_MAX_TOOL_ITERATIONS"`
@@ -131,6 +131,11 @@ type AgentDefaults struct {
SummaryPolicy SystemSummaryPolicyConfig `json:"summary_policy"`
}
type AgentModelDefaults struct {
Primary string `json:"primary,omitempty" env:"CLAWGO_AGENTS_DEFAULTS_MODEL_PRIMARY"`
Fallbacks []string `json:"fallbacks,omitempty" env:"CLAWGO_AGENTS_DEFAULTS_MODEL_FALLBACKS"`
}
type HeartbeatConfig struct {
Enabled bool `json:"enabled" env:"CLAWGO_AGENTS_DEFAULTS_HEARTBEAT_ENABLED"`
EverySec int `json:"every_sec" env:"CLAWGO_AGENTS_DEFAULTS_HEARTBEAT_EVERY_SEC"`
@@ -234,52 +239,8 @@ type DingTalkConfig struct {
AllowFrom []string `json:"allow_from" env:"CLAWGO_CHANNELS_DINGTALK_ALLOW_FROM"`
}
type ProvidersConfig struct {
Proxy ProviderConfig `json:"proxy"`
Proxies map[string]ProviderConfig `json:"proxies"`
}
type providerProxyItem struct {
Name string `json:"name"`
ProviderConfig
}
func (p *ProvidersConfig) UnmarshalJSON(data []byte) error {
var tmp struct {
Proxy ProviderConfig `json:"proxy"`
Proxies json.RawMessage `json:"proxies"`
}
if err := json.Unmarshal(data, &tmp); err != nil {
return err
}
p.Proxy = tmp.Proxy
p.Proxies = map[string]ProviderConfig{}
if len(bytes.TrimSpace(tmp.Proxies)) == 0 || string(bytes.TrimSpace(tmp.Proxies)) == "null" {
return nil
}
// Preferred format: object map
var asMap map[string]ProviderConfig
if err := json.Unmarshal(tmp.Proxies, &asMap); err == nil {
for k, v := range asMap {
if k == "" {
continue
}
p.Proxies[k] = v
}
return nil
}
// Compatibility format: array [{name, ...provider fields...}]
var asArr []providerProxyItem
if err := json.Unmarshal(tmp.Proxies, &asArr); err == nil {
for _, it := range asArr {
if it.Name == "" {
continue
}
p.Proxies[it.Name] = it.ProviderConfig
}
return nil
}
return fmt.Errorf("providers.proxies must be object map or array of {name,...}")
type ModelsConfig struct {
Providers map[string]ProviderConfig `json:"providers,omitempty"`
}
type ProviderConfig struct {
@@ -298,6 +259,7 @@ type ProviderConfig struct {
type ProviderOAuthConfig struct {
Provider string `json:"provider,omitempty"`
NetworkProxy string `json:"network_proxy,omitempty"`
CredentialFile string `json:"credential_file,omitempty"`
CredentialFiles []string `json:"credential_files,omitempty"`
CallbackPort int `json:"callback_port,omitempty"`
@@ -307,7 +269,6 @@ type ProviderOAuthConfig struct {
TokenURL string `json:"token_url,omitempty"`
RedirectURL string `json:"redirect_url,omitempty"`
Scopes []string `json:"scopes,omitempty"`
HybridPriority string `json:"hybrid_priority,omitempty"`
CooldownSec int `json:"cooldown_sec,omitempty"`
RefreshScanSec int `json:"refresh_scan_sec,omitempty"`
RefreshLeadSec int `json:"refresh_lead_sec,omitempty"`
@@ -484,8 +445,7 @@ func DefaultConfig() *Config {
Agents: AgentsConfig{
Defaults: AgentDefaults{
Workspace: filepath.Join(configDir, "workspace"),
Proxy: "proxy",
ProxyFallbacks: []string{},
Model: AgentModelDefaults{Primary: "openai/gpt-5.4", Fallbacks: []string{}},
MaxTokens: 8192,
Temperature: 0.7,
MaxToolIterations: 20,
@@ -599,13 +559,14 @@ func DefaultConfig() *Config {
AllowFrom: []string{},
},
},
Providers: ProvidersConfig{
Proxy: ProviderConfig{
APIBase: "http://localhost:8080/v1",
Models: []string{"glm-4.7"},
TimeoutSec: 90,
Models: ModelsConfig{
Providers: map[string]ProviderConfig{
"openai": {
APIBase: "https://api.openai.com/v1",
Models: []string{"gpt-5.4"},
TimeoutSec: 90,
},
},
Proxies: map[string]ProviderConfig{},
},
Gateway: GatewayConfig{
Host: "0.0.0.0",
@@ -699,6 +660,60 @@ func DefaultConfig() *Config {
}
}
func ParseProviderModelRef(raw string) (provider string, model string) {
trimmed := strings.TrimSpace(raw)
if trimmed == "" {
return "", ""
}
if idx := strings.Index(trimmed, "/"); idx > 0 {
return strings.TrimSpace(trimmed[:idx]), strings.TrimSpace(trimmed[idx+1:])
}
return "", trimmed
}
func AllProviderConfigs(cfg *Config) map[string]ProviderConfig {
out := map[string]ProviderConfig{}
if cfg == nil {
return out
}
for name, pc := range cfg.Models.Providers {
trimmed := strings.TrimSpace(name)
if trimmed == "" {
continue
}
out[trimmed] = pc
}
return out
}
func ProviderConfigByName(cfg *Config, name string) (ProviderConfig, bool) {
if cfg == nil {
return ProviderConfig{}, false
}
pc, ok := AllProviderConfigs(cfg)[strings.TrimSpace(name)]
return pc, ok
}
func ProviderExists(cfg *Config, name string) bool {
_, ok := ProviderConfigByName(cfg, name)
return ok
}
func PrimaryProviderName(cfg *Config) string {
if cfg == nil {
return "openai"
}
if provider, _ := ParseProviderModelRef(cfg.Agents.Defaults.Model.Primary); provider != "" {
return provider
}
for name := range cfg.Models.Providers {
if trimmed := strings.TrimSpace(name); trimmed != "" {
return trimmed
}
}
return "openai"
}
func generateGatewayToken() string {
var buf [16]byte
if _, err := rand.Read(buf[:]); err != nil {
@@ -771,13 +786,19 @@ func (c *Config) WorkspacePath() string {
func (c *Config) GetAPIKey() string {
c.mu.RLock()
defer c.mu.RUnlock()
return c.Providers.Proxy.APIKey
if pc, ok := c.Models.Providers[PrimaryProviderName(c)]; ok {
return pc.APIKey
}
return ""
}
func (c *Config) GetAPIBase() string {
c.mu.RLock()
defer c.mu.RUnlock()
return c.Providers.Proxy.APIBase
if pc, ok := c.Models.Providers[PrimaryProviderName(c)]; ok {
return pc.APIBase
}
return ""
}
func (c *Config) LogFilePath() string {

View File

@@ -87,30 +87,33 @@ func Validate(cfg *Config) []error {
}
}
if len(cfg.Providers.Proxies) == 0 {
errs = append(errs, validateProviderConfig("providers.proxy", cfg.Providers.Proxy)...)
} else {
for name, p := range cfg.Providers.Proxies {
errs = append(errs, validateProviderConfig("providers.proxies."+name, p)...)
for name, p := range cfg.Models.Providers {
errs = append(errs, validateProviderConfig("models.providers."+name, p)...)
}
if len(cfg.Models.Providers) == 0 {
errs = append(errs, fmt.Errorf("models.providers must contain at least one provider"))
}
for _, name := range cfg.Agents.Defaults.Model.Fallbacks {
if !ProviderExists(cfg, name) {
errs = append(errs, fmt.Errorf("agents.defaults.model.fallbacks contains unknown provider %q", name))
}
}
if cfg.Agents.Defaults.Proxy != "" {
if !providerExists(cfg, cfg.Agents.Defaults.Proxy) {
errs = append(errs, fmt.Errorf("agents.defaults.proxy %q not found in providers", cfg.Agents.Defaults.Proxy))
if primaryRef := strings.TrimSpace(cfg.Agents.Defaults.Model.Primary); primaryRef != "" {
providerName, modelName := ParseProviderModelRef(primaryRef)
if providerName == "" {
providerName = PrimaryProviderName(cfg)
}
}
for _, name := range cfg.Agents.Defaults.ProxyFallbacks {
if !providerExists(cfg, name) {
errs = append(errs, fmt.Errorf("agents.defaults.proxy_fallbacks contains unknown proxy %q", name))
if !ProviderExists(cfg, providerName) {
errs = append(errs, fmt.Errorf("agents.defaults.model.primary %q references unknown provider %q", primaryRef, providerName))
}
if strings.TrimSpace(modelName) == "" {
errs = append(errs, fmt.Errorf("agents.defaults.model.primary must include a model, expected provider/model"))
}
}
if cfg.Agents.Defaults.ContextCompaction.Enabled && cfg.Agents.Defaults.ContextCompaction.Mode == "responses_compact" {
active := cfg.Agents.Defaults.Proxy
if active == "" {
active = "proxy"
}
if pc, ok := providerConfigByName(cfg, active); !ok || !pc.SupportsResponsesCompact {
errs = append(errs, fmt.Errorf("context_compaction.mode=responses_compact requires active proxy %q with supports_responses_compact=true", active))
active := PrimaryProviderName(cfg)
if pc, ok := ProviderConfigByName(cfg, active); !ok || !pc.SupportsResponsesCompact {
errs = append(errs, fmt.Errorf("context_compaction.mode=responses_compact requires active provider %q with supports_responses_compact=true", active))
}
}
errs = append(errs, validateAgentRouter(cfg)...)
@@ -474,8 +477,8 @@ func validateSubagents(cfg *Config) []error {
errs = append(errs, fmt.Errorf("agents.subagents.%s.system_prompt_file must stay within workspace", id))
}
}
if proxy := strings.TrimSpace(raw.Runtime.Proxy); proxy != "" && !providerExists(cfg, proxy) {
errs = append(errs, fmt.Errorf("agents.subagents.%s.runtime.proxy %q not found in providers", id, proxy))
if provider := strings.TrimSpace(raw.Runtime.Provider); provider != "" && !ProviderExists(cfg, provider) {
errs = append(errs, fmt.Errorf("agents.subagents.%s.runtime.provider %q not found in providers", id, provider))
}
for _, sender := range raw.AcceptFrom {
sender = strings.TrimSpace(sender)
@@ -562,13 +565,6 @@ func validateProviderConfig(path string, p ProviderConfig) []error {
errs = append(errs, fmt.Errorf("%s.oauth.provider is required when auth=hybrid", path))
}
}
if p.OAuth.HybridPriority != "" {
switch strings.ToLower(strings.TrimSpace(p.OAuth.HybridPriority)) {
case "api_first", "oauth_first":
default:
errs = append(errs, fmt.Errorf("%s.oauth.hybrid_priority must be one of: api_first, oauth_first", path))
}
}
if p.OAuth.CooldownSec < 0 {
errs = append(errs, fmt.Errorf("%s.oauth.cooldown_sec must be >= 0", path))
}
@@ -587,25 +583,6 @@ func validateProviderConfig(path string, p ProviderConfig) []error {
return errs
}
func providerExists(cfg *Config, name string) bool {
if name == "proxy" && cfg.Providers.Proxy.APIBase != "" {
return true
}
if cfg.Providers.Proxies == nil {
return false
}
_, ok := cfg.Providers.Proxies[name]
return ok
}
func providerConfigByName(cfg *Config, name string) (ProviderConfig, bool) {
if strings.TrimSpace(name) == "proxy" {
return cfg.Providers.Proxy, true
}
pc, ok := cfg.Providers.Proxies[name]
return pc, ok
}
func validateNonEmptyStringList(path string, values []string) []error {
if len(values) == 0 {
return nil

View File

@@ -34,7 +34,7 @@ func TestValidateSubagentsAllowsKnownPeers(t *testing.T) {
AcceptFrom: []string{"main"},
CanTalkTo: []string{"main"},
Runtime: SubagentRuntimeConfig{
Proxy: "proxy",
Provider: "openai",
},
}
@@ -66,7 +66,7 @@ func TestValidateSubagentsRejectsAbsolutePromptFile(t *testing.T) {
Enabled: true,
SystemPromptFile: "/tmp/AGENT.md",
Runtime: SubagentRuntimeConfig{
Proxy: "proxy",
Provider: "openai",
},
}
@@ -82,7 +82,7 @@ func TestValidateSubagentsRequiresPromptFileWhenEnabled(t *testing.T) {
cfg.Agents.Subagents["coder"] = SubagentConfig{
Enabled: true,
Runtime: SubagentRuntimeConfig{
Proxy: "proxy",
Provider: "openai",
},
}
@@ -123,7 +123,7 @@ func TestValidateSubagentsRejectsInvalidNotifyMainPolicy(t *testing.T) {
SystemPromptFile: "agents/coder/AGENT.md",
NotifyMainPolicy: "loud",
Runtime: SubagentRuntimeConfig{
Proxy: "proxy",
Provider: "openai",
},
}
@@ -276,9 +276,11 @@ func TestValidateProviderOAuthAllowsEmptyModelsBeforeLogin(t *testing.T) {
t.Parallel()
cfg := DefaultConfig()
cfg.Providers.Proxy.Auth = "oauth"
cfg.Providers.Proxy.Models = nil
cfg.Providers.Proxy.OAuth = ProviderOAuthConfig{Provider: "codex"}
pc := cfg.Models.Providers["openai"]
pc.Auth = "oauth"
pc.Models = nil
pc.OAuth = ProviderOAuthConfig{Provider: "codex"}
cfg.Models.Providers["openai"] = pc
if errs := Validate(cfg); len(errs) != 0 {
t.Fatalf("expected oauth provider config to be valid before model sync, got %v", errs)
@@ -289,9 +291,11 @@ func TestValidateProviderOAuthRequiresProviderName(t *testing.T) {
t.Parallel()
cfg := DefaultConfig()
cfg.Providers.Proxy.Auth = "oauth"
cfg.Providers.Proxy.Models = nil
cfg.Providers.Proxy.OAuth = ProviderOAuthConfig{}
pc := cfg.Models.Providers["openai"]
pc.Auth = "oauth"
pc.Models = nil
pc.OAuth = ProviderOAuthConfig{}
cfg.Models.Providers["openai"] = pc
errs := Validate(cfg)
if len(errs) == 0 {
@@ -299,7 +303,7 @@ func TestValidateProviderOAuthRequiresProviderName(t *testing.T) {
}
found := false
for _, err := range errs {
if strings.Contains(err.Error(), "providers.proxy.oauth.provider") {
if strings.Contains(err.Error(), "models.providers.openai.oauth.provider") {
found = true
break
}
@@ -313,10 +317,12 @@ func TestValidateProviderHybridAllowsEmptyModels(t *testing.T) {
t.Parallel()
cfg := DefaultConfig()
cfg.Providers.Proxy.Auth = "hybrid"
cfg.Providers.Proxy.APIKey = "sk-test"
cfg.Providers.Proxy.Models = nil
cfg.Providers.Proxy.OAuth = ProviderOAuthConfig{Provider: "codex"}
pc := cfg.Models.Providers["openai"]
pc.Auth = "hybrid"
pc.APIKey = "sk-test"
pc.Models = nil
pc.OAuth = ProviderOAuthConfig{Provider: "codex"}
cfg.Models.Providers["openai"] = pc
if errs := Validate(cfg); len(errs) != 0 {
t.Fatalf("expected hybrid provider config to be valid before model sync, got %v", errs)
@@ -327,10 +333,12 @@ func TestValidateProviderHybridRequiresOAuthProvider(t *testing.T) {
t.Parallel()
cfg := DefaultConfig()
cfg.Providers.Proxy.Auth = "hybrid"
cfg.Providers.Proxy.APIKey = "sk-test"
cfg.Providers.Proxy.Models = nil
cfg.Providers.Proxy.OAuth = ProviderOAuthConfig{}
pc := cfg.Models.Providers["openai"]
pc.Auth = "hybrid"
pc.APIKey = "sk-test"
pc.Models = nil
pc.OAuth = ProviderOAuthConfig{}
cfg.Models.Providers["openai"] = pc
errs := Validate(cfg)
if len(errs) == 0 {
@@ -338,7 +346,7 @@ func TestValidateProviderHybridRequiresOAuthProvider(t *testing.T) {
}
found := false
for _, err := range errs {
if strings.Contains(err.Error(), "providers.proxy.oauth.provider") {
if strings.Contains(err.Error(), "models.providers.openai.oauth.provider") {
found = true
break
}
@@ -347,30 +355,3 @@ func TestValidateProviderHybridRequiresOAuthProvider(t *testing.T) {
t.Fatalf("expected oauth.provider validation error, got %v", errs)
}
}
func TestValidateProviderHybridPriorityRejectsInvalidValue(t *testing.T) {
t.Parallel()
cfg := DefaultConfig()
cfg.Providers.Proxy.Auth = "hybrid"
cfg.Providers.Proxy.APIKey = "sk-test"
cfg.Providers.Proxy.OAuth = ProviderOAuthConfig{
Provider: "codex",
HybridPriority: "random_first",
}
errs := Validate(cfg)
if len(errs) == 0 {
t.Fatalf("expected validation errors")
}
found := false
for _, err := range errs {
if strings.Contains(err.Error(), "oauth.hybrid_priority") {
found = true
break
}
}
if !found {
t.Fatalf("expected oauth.hybrid_priority validation error, got %v", errs)
}
}

View File

@@ -1,6 +1,7 @@
package providers
import (
"context"
"net"
"net/http"
"strings"
@@ -16,16 +17,25 @@ type anthropicOAuthRoundTripper struct {
connections map[string]*http2.ClientConn
pending map[string]*sync.Cond
dialer net.Dialer
dialContext func(context.Context, string, string) (net.Conn, error)
}
func newAnthropicOAuthHTTPClient(timeout time.Duration) *http.Client {
func newAnthropicOAuthHTTPClient(timeout time.Duration, proxyURL string) (*http.Client, error) {
rt, err := newAnthropicOAuthRoundTripper(proxyURL)
if err != nil {
return nil, err
}
return &http.Client{
Timeout: timeout,
Transport: newAnthropicOAuthRoundTripper(),
}
Transport: rt,
}, nil
}
func newAnthropicOAuthRoundTripper() *anthropicOAuthRoundTripper {
func newAnthropicOAuthRoundTripper(proxyURL string) (*anthropicOAuthRoundTripper, error) {
dialContext, err := proxyDialContext(proxyURL)
if err != nil {
return nil, err
}
return &anthropicOAuthRoundTripper{
connections: map[string]*http2.ClientConn{},
pending: map[string]*sync.Cond{},
@@ -33,7 +43,8 @@ func newAnthropicOAuthRoundTripper() *anthropicOAuthRoundTripper {
Timeout: 15 * time.Second,
KeepAlive: 30 * time.Second,
},
}
dialContext: dialContext,
}, nil
}
func (t *anthropicOAuthRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
@@ -89,7 +100,11 @@ func (t *anthropicOAuthRoundTripper) getOrCreateConnection(host, addr string) (*
}
func (t *anthropicOAuthRoundTripper) createConnection(host, addr string) (*http2.ClientConn, error) {
rawConn, err := t.dialer.Dial("tcp", addr)
dialContext := t.dialContext
if dialContext == nil {
dialContext = t.dialer.DialContext
}
rawConn, err := dialContext(context.Background(), "tcp", addr)
if err != nil {
return nil, err
}

View File

@@ -665,7 +665,7 @@ func (p *HTTPProvider) postJSONStream(ctx context.Context, endpoint string, payl
req.Header.Set("Accept", "text/event-stream")
applyAttemptAuth(req, attempt)
body, status, ctype, quotaHit, err := p.doStreamAttempt(req, onEvent)
body, status, ctype, quotaHit, err := p.doStreamAttempt(req, attempt, onEvent)
if err != nil {
return nil, 0, "", err
}
@@ -707,7 +707,7 @@ func (p *HTTPProvider) postJSON(ctx context.Context, endpoint string, payload in
req.Header.Set("Content-Type", "application/json")
applyAttemptAuth(req, attempt)
body, status, ctype, err := p.doJSONAttempt(req)
body, status, ctype, err := p.doJSONAttempt(req, attempt)
if err != nil {
return nil, 0, "", err
}
@@ -753,7 +753,7 @@ func (p *HTTPProvider) authAttempts(ctx context.Context) ([]authAttempt, error)
for _, attempt := range attempts {
oauthAttempts = append(oauthAttempts, authAttempt{session: attempt.Session, token: attempt.Token, kind: "oauth"})
}
if mode == "hybrid" && apiReady && p.oauth.cfg.HybridPriority != "oauth_first" {
if mode == "hybrid" && apiReady {
out = append(out, apiAttempt)
}
if len(attempts) == 0 {
@@ -764,9 +764,6 @@ func (p *HTTPProvider) authAttempts(ctx context.Context) ([]authAttempt, error)
return nil, fmt.Errorf("oauth session not found, run `clawgo provider login` first")
}
out = append(out, oauthAttempts...)
if mode == "hybrid" && apiReady && p.oauth.cfg.HybridPriority == "oauth_first" {
out = append(out, apiAttempt)
}
p.updateCandidateOrder(out)
return out, nil
}
@@ -833,8 +830,19 @@ func applyAttemptAuth(req *http.Request, attempt authAttempt) {
req.Header.Set("Authorization", "Bearer "+attempt.token)
}
func (p *HTTPProvider) doJSONAttempt(req *http.Request) ([]byte, int, string, error) {
resp, err := p.httpClient.Do(req)
func (p *HTTPProvider) httpClientForAttempt(attempt authAttempt) (*http.Client, error) {
if attempt.kind == "oauth" && attempt.session != nil && p.oauth != nil {
return p.oauth.httpClientForSession(attempt.session)
}
return p.httpClient, nil
}
func (p *HTTPProvider) doJSONAttempt(req *http.Request, attempt authAttempt) ([]byte, int, string, error) {
client, err := p.httpClientForAttempt(attempt)
if err != nil {
return nil, 0, "", err
}
resp, err := client.Do(req)
if err != nil {
return nil, 0, "", fmt.Errorf("failed to send request: %w", err)
}
@@ -846,8 +854,12 @@ func (p *HTTPProvider) doJSONAttempt(req *http.Request) ([]byte, int, string, er
return body, resp.StatusCode, strings.TrimSpace(resp.Header.Get("Content-Type")), nil
}
func (p *HTTPProvider) doStreamAttempt(req *http.Request, onEvent func(string)) ([]byte, int, string, bool, error) {
resp, err := p.httpClient.Do(req)
func (p *HTTPProvider) doStreamAttempt(req *http.Request, attempt authAttempt, onEvent func(string)) ([]byte, int, string, bool, error) {
client, err := p.httpClientForAttempt(attempt)
if err != nil {
return nil, 0, "", false, err
}
resp, err := client.Do(req)
if err != nil {
return nil, 0, "", false, fmt.Errorf("failed to send request: %w", err)
}
@@ -1437,17 +1449,10 @@ func buildProviderCandidateOrder(_ string, pc config.ProviderConfig, accounts []
case "oauth":
out = append(out, oauthAvailable...)
case "hybrid":
if strings.EqualFold(strings.TrimSpace(pc.OAuth.HybridPriority), "oauth_first") {
out = append(out, oauthAvailable...)
if apiCandidate.Target != "" && apiCandidate.Available {
out = append(out, apiCandidate)
}
} else {
if apiCandidate.Target != "" && apiCandidate.Available {
out = append(out, apiCandidate)
}
out = append(out, oauthAvailable...)
if apiCandidate.Target != "" && apiCandidate.Available {
out = append(out, apiCandidate)
}
out = append(out, oauthAvailable...)
case "none":
default:
if apiCandidate.Target != "" {
@@ -2131,11 +2136,16 @@ func (p *HTTPProvider) BuildSummaryViaResponsesCompact(ctx context.Context, mode
}
func CreateProvider(cfg *config.Config) (LLMProvider, error) {
name := strings.TrimSpace(cfg.Agents.Defaults.Proxy)
if name == "" {
name = "proxy"
name := config.PrimaryProviderName(cfg)
provider, err := CreateProviderByName(cfg, name)
if err != nil {
return nil, err
}
return CreateProviderByName(cfg, name)
_, model := config.ParseProviderModelRef(cfg.Agents.Defaults.Model.Primary)
if hp, ok := provider.(*HTTPProvider); ok && strings.TrimSpace(model) != "" {
hp.defaultModel = strings.TrimSpace(model)
}
return provider, nil
}
func CreateProviderByName(cfg *config.Config, name string) (LLMProvider, error) {
@@ -2219,48 +2229,12 @@ func ListProviderNames(cfg *config.Config) []string {
}
func getAllProviderConfigs(cfg *config.Config) map[string]config.ProviderConfig {
out := map[string]config.ProviderConfig{}
if cfg == nil {
return out
}
includeLegacyProxy := len(cfg.Providers.Proxies) == 0 || strings.TrimSpace(cfg.Agents.Defaults.Proxy) == "proxy" || containsStringTrimmed(cfg.Agents.Defaults.ProxyFallbacks, "proxy")
if includeLegacyProxy && (cfg.Providers.Proxy.APIBase != "" || cfg.Providers.Proxy.APIKey != "" || cfg.Providers.Proxy.TimeoutSec > 0) {
out["proxy"] = cfg.Providers.Proxy
}
for name, pc := range cfg.Providers.Proxies {
trimmed := strings.TrimSpace(name)
if trimmed == "" {
continue
}
out[trimmed] = pc
}
return out
}
func containsStringTrimmed(values []string, target string) bool {
t := strings.TrimSpace(target)
for _, v := range values {
if strings.TrimSpace(v) == t {
return true
}
}
return false
return config.AllProviderConfigs(cfg)
}
func getProviderConfigByName(cfg *config.Config, name string) (config.ProviderConfig, error) {
if cfg == nil {
return config.ProviderConfig{}, fmt.Errorf("nil config")
if pc, ok := config.ProviderConfigByName(cfg, name); ok {
return pc, nil
}
trimmed := strings.TrimSpace(name)
if trimmed == "" {
return config.ProviderConfig{}, fmt.Errorf("empty provider name")
}
if trimmed == "proxy" {
return cfg.Providers.Proxy, nil
}
pc, ok := cfg.Providers.Proxies[trimmed]
if !ok {
return config.ProviderConfig{}, fmt.Errorf("provider %q not found", trimmed)
}
return pc, nil
return config.ProviderConfig{}, fmt.Errorf("provider %q not found", strings.TrimSpace(name))
}

156
pkg/providers/http_proxy.go Normal file
View File

@@ -0,0 +1,156 @@
package providers
import (
"bufio"
"context"
stdtls "crypto/tls"
"encoding/base64"
"fmt"
"net"
"net/http"
"net/url"
"strings"
"time"
xproxy "golang.org/x/net/proxy"
)
func normalizeOptionalProxyURL(raw string) (string, error) {
value := strings.TrimSpace(raw)
if value == "" {
return "", nil
}
if !strings.Contains(value, "://") {
value = "http://" + value
}
parsed, err := url.Parse(value)
if err != nil {
return "", fmt.Errorf("invalid network proxy: %w", err)
}
if parsed.Scheme == "" || parsed.Host == "" {
return "", fmt.Errorf("invalid network proxy: host is required")
}
switch strings.ToLower(strings.TrimSpace(parsed.Scheme)) {
case "http", "https", "socks5", "socks5h":
return parsed.String(), nil
default:
return "", fmt.Errorf("invalid network proxy: unsupported scheme %q", parsed.Scheme)
}
}
func maskedProxyURL(raw string) string {
normalized, err := normalizeOptionalProxyURL(raw)
if err != nil || normalized == "" {
return ""
}
parsed, err := url.Parse(normalized)
if err != nil {
return ""
}
if parsed.User != nil {
username := parsed.User.Username()
if username != "" {
parsed.User = url.UserPassword(username, "***")
} else {
parsed.User = url.User("***")
}
}
return parsed.String()
}
func proxyDialContext(proxyRaw string) (func(context.Context, string, string) (net.Conn, error), error) {
normalized, err := normalizeOptionalProxyURL(proxyRaw)
if err != nil {
return nil, err
}
if normalized == "" {
dialer := &net.Dialer{Timeout: 15 * time.Second, KeepAlive: 30 * time.Second}
return dialer.DialContext, nil
}
parsed, err := url.Parse(normalized)
if err != nil {
return nil, err
}
switch strings.ToLower(strings.TrimSpace(parsed.Scheme)) {
case "socks5", "socks5h":
base := &net.Dialer{Timeout: 15 * time.Second, KeepAlive: 30 * time.Second}
dialer, err := xproxy.FromURL(parsed, base)
if err != nil {
return nil, fmt.Errorf("configure socks proxy failed: %w", err)
}
if ctxDialer, ok := dialer.(xproxy.ContextDialer); ok {
return ctxDialer.DialContext, nil
}
return func(ctx context.Context, network, addr string) (net.Conn, error) {
type dialResult struct {
conn net.Conn
err error
}
ch := make(chan dialResult, 1)
go func() {
conn, err := dialer.Dial(network, addr)
ch <- dialResult{conn: conn, err: err}
}()
select {
case <-ctx.Done():
return nil, ctx.Err()
case res := <-ch:
return res.conn, res.err
}
}, nil
case "http", "https":
base := &net.Dialer{Timeout: 15 * time.Second, KeepAlive: 30 * time.Second}
return func(ctx context.Context, network, addr string) (net.Conn, error) {
conn, err := base.DialContext(ctx, "tcp", parsed.Host)
if err != nil {
return nil, err
}
if strings.EqualFold(parsed.Scheme, "https") {
tlsConn := stdtls.Client(conn, &stdtls.Config{ServerName: parsed.Hostname()})
if err := tlsConn.HandshakeContext(ctx); err != nil {
_ = conn.Close()
return nil, err
}
conn = tlsConn
}
connectReq := buildProxyConnectRequest(parsed, addr)
if _, err := conn.Write([]byte(connectReq)); err != nil {
_ = conn.Close()
return nil, err
}
br := bufio.NewReader(conn)
resp, err := http.ReadResponse(br, &http.Request{Method: http.MethodConnect})
if err != nil {
_ = conn.Close()
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
_ = conn.Close()
return nil, fmt.Errorf("proxy connect failed: status=%d", resp.StatusCode)
}
return conn, nil
}, nil
default:
return nil, fmt.Errorf("invalid network proxy: unsupported scheme %q", parsed.Scheme)
}
}
func buildProxyConnectRequest(proxyURL *url.URL, targetAddr string) string {
var b strings.Builder
b.WriteString("CONNECT ")
b.WriteString(targetAddr)
b.WriteString(" HTTP/1.1\r\nHost: ")
b.WriteString(targetAddr)
b.WriteString("\r\n")
if proxyURL != nil && proxyURL.User != nil {
username := proxyURL.User.Username()
password, _ := proxyURL.User.Password()
encoded := base64.StdEncoding.EncodeToString([]byte(username + ":" + password))
b.WriteString("Proxy-Authorization: Basic ")
b.WriteString(encoded)
b.WriteString("\r\n")
}
b.WriteString("\r\n")
return b.String()
}

View File

@@ -118,6 +118,7 @@ type oauthSession struct {
ProjectID string `json:"project_id,omitempty"`
DeviceID string `json:"device_id,omitempty"`
ResourceURL string `json:"resource_url,omitempty"`
NetworkProxy string `json:"network_proxy,omitempty"`
Scope string `json:"scope,omitempty"`
Token map[string]any `json:"token,omitempty"`
CooldownUntil string `json:"-"`
@@ -143,7 +144,6 @@ type oauthConfig struct {
Scopes []string
RefreshScan time.Duration
RefreshLead time.Duration
HybridPriority string
Cooldown time.Duration
FlowKind string
TokenStyle string
@@ -154,7 +154,10 @@ type oauthConfig struct {
type oauthManager struct {
providerName string
cfg oauthConfig
timeout time.Duration
httpClient *http.Client
clientMu sync.Mutex
clients map[string]*http.Client
mu sync.Mutex
cached []*oauthSession
cooldowns map[string]time.Time
@@ -198,6 +201,7 @@ type OAuthLoginOptions struct {
NoBrowser bool
Reader io.Reader
AccountLabel string
NetworkProxy string
}
type OAuthPendingFlow struct {
@@ -218,6 +222,7 @@ type OAuthSessionInfo struct {
CredentialFile string
ProjectID string
AccountLabel string
NetworkProxy string
}
type OAuthAccountInfo struct {
@@ -230,6 +235,7 @@ type OAuthAccountInfo struct {
AccountLabel string `json:"account_label,omitempty"`
DeviceID string `json:"device_id,omitempty"`
ResourceURL string `json:"resource_url,omitempty"`
NetworkProxy string `json:"network_proxy,omitempty"`
CooldownUntil string `json:"cooldown_until,omitempty"`
FailureCount int `json:"failure_count,omitempty"`
LastFailure string `json:"last_failure,omitempty"`
@@ -277,6 +283,7 @@ func (m *OAuthLoginManager) Login(ctx context.Context, apiBase string, opts OAut
CredentialFile: session.FilePath,
ProjectID: session.ProjectID,
AccountLabel: sessionLabel(session),
NetworkProxy: maskedProxyURL(session.NetworkProxy),
}, models, nil
}
@@ -288,13 +295,20 @@ func (m *OAuthLoginManager) CredentialFile() string {
}
func (m *OAuthLoginManager) StartManualFlow() (*OAuthPendingFlow, error) {
return m.StartManualFlowWithOptions(OAuthLoginOptions{})
}
func (m *OAuthLoginManager) StartManualFlowWithOptions(opts OAuthLoginOptions) (*OAuthPendingFlow, error) {
if m == nil || m.manager == nil {
return nil, fmt.Errorf("oauth login manager not configured")
}
if m.manager.cfg.FlowKind == oauthFlowDevice {
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
return m.manager.startDeviceFlow(ctx)
return m.manager.startDeviceFlow(ctx, opts)
}
if _, err := normalizeOptionalProxyURL(opts.NetworkProxy); err != nil {
return nil, err
}
pkceVerifier, pkceChallenge, err := generatePKCE()
if err != nil {
@@ -350,6 +364,7 @@ func (m *OAuthLoginManager) CompleteManualFlowWithOptions(ctx context.Context, a
CredentialFile: session.FilePath,
ProjectID: session.ProjectID,
AccountLabel: sessionLabel(session),
NetworkProxy: maskedProxyURL(session.NetworkProxy),
}, models, nil
}
@@ -371,6 +386,7 @@ func (m *OAuthLoginManager) ImportAuthJSONWithOptions(ctx context.Context, apiBa
CredentialFile: session.FilePath,
ProjectID: session.ProjectID,
AccountLabel: sessionLabel(session),
NetworkProxy: maskedProxyURL(session.NetworkProxy),
}, models, nil
}
@@ -399,6 +415,7 @@ func (m *OAuthLoginManager) ListAccounts() ([]OAuthAccountInfo, error) {
AccountLabel: sessionLabel(session),
DeviceID: session.DeviceID,
ResourceURL: session.ResourceURL,
NetworkProxy: maskedProxyURL(session.NetworkProxy),
CooldownUntil: session.CooldownUntil,
FailureCount: session.FailureCount,
LastFailure: session.LastFailure,
@@ -436,6 +453,7 @@ func (m *OAuthLoginManager) RefreshAccount(ctx context.Context, credentialFile s
AccountLabel: sessionLabel(refreshed),
DeviceID: refreshed.DeviceID,
ResourceURL: refreshed.ResourceURL,
NetworkProxy: maskedProxyURL(refreshed.NetworkProxy),
CooldownUntil: refreshed.CooldownUntil,
FailureCount: refreshed.FailureCount,
LastFailure: refreshed.LastFailure,
@@ -506,10 +524,16 @@ func newOAuthManager(pc config.ProviderConfig, timeout time.Duration) (*oauthMan
if err != nil {
return nil, err
}
client, err := newOAuthHTTPClient(resolved.Provider, timeout, "")
if err != nil {
return nil, err
}
bgCtx, bgCancel := context.WithCancel(context.Background())
manager := &oauthManager{
cfg: resolved,
httpClient: newOAuthHTTPClient(resolved.Provider, timeout),
timeout: timeout,
httpClient: client,
clients: map[string]*http.Client{"": client},
cooldowns: map[string]time.Time{},
bgCtx: bgCtx,
bgCancel: bgCancel,
@@ -518,11 +542,32 @@ func newOAuthManager(pc config.ProviderConfig, timeout time.Duration) (*oauthMan
return manager, nil
}
func newOAuthHTTPClient(provider string, timeout time.Duration) *http.Client {
if provider == defaultClaudeOAuthProvider {
return newAnthropicOAuthHTTPClient(timeout)
func newOAuthHTTPClient(provider string, timeout time.Duration, proxyURL string) (*http.Client, error) {
normalizedProxy, err := normalizeOptionalProxyURL(proxyURL)
if err != nil {
return nil, err
}
return &http.Client{Timeout: timeout}
if provider == defaultClaudeOAuthProvider {
return newAnthropicOAuthHTTPClient(timeout, normalizedProxy)
}
if normalizedProxy == "" {
return &http.Client{Timeout: timeout}, nil
}
parsed, err := url.Parse(normalizedProxy)
if err != nil {
return nil, err
}
return &http.Client{
Timeout: timeout,
Transport: &http.Transport{
Proxy: http.ProxyURL(parsed),
ForceAttemptHTTP2: true,
MaxIdleConns: 100,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 15 * time.Second,
ExpectContinueTimeout: time.Second,
},
}, nil
}
func resolveOAuthConfig(pc config.ProviderConfig) (oauthConfig, error) {
@@ -545,7 +590,6 @@ func resolveOAuthConfig(pc config.ProviderConfig) (oauthConfig, error) {
TokenURL: strings.TrimSpace(pc.OAuth.TokenURL),
RedirectURL: strings.TrimSpace(pc.OAuth.RedirectURL),
Scopes: trimNonEmptyStrings(pc.OAuth.Scopes),
HybridPriority: normalizeHybridPriority(pc.OAuth.HybridPriority),
Cooldown: durationFromSeconds(pc.OAuth.CooldownSec, 15*time.Minute),
RefreshScan: durationFromSeconds(pc.OAuth.RefreshScanSec, 10*time.Minute),
RefreshLead: defaultRefreshLead(provider, pc.OAuth.RefreshLeadSec),
@@ -626,9 +670,6 @@ func resolveOAuthConfig(pc config.ProviderConfig) (oauthConfig, error) {
cfg.CredentialFiles = uniqueStrings(append([]string{cfg.CredentialFile}, cfg.CredentialFiles...))
cfg.CredentialFile = cfg.CredentialFiles[0]
}
if cfg.HybridPriority == "" {
cfg.HybridPriority = "api_first"
}
return cfg, nil
}
@@ -675,7 +716,12 @@ func (m *oauthManager) models(ctx context.Context, apiBase string) ([]string, er
seen := map[string]struct{}{}
var lastErr error
for _, attempt := range attempts {
models, err := fetchOpenAIModels(ctx, m.httpClient, apiBase, attempt.Token)
client, err := m.httpClientForSession(attempt.Session)
if err != nil {
lastErr = err
continue
}
models, err := fetchOpenAIModels(ctx, client, apiBase, attempt.Token)
if err != nil {
lastErr = err
continue
@@ -704,7 +750,7 @@ func (m *oauthManager) login(ctx context.Context, apiBase string, opts OAuthLogi
return nil, nil, fmt.Errorf("oauth manager not configured")
}
if m.cfg.FlowKind == oauthFlowDevice {
flow, err := m.startDeviceFlow(ctx)
flow, err := m.startDeviceFlow(ctx, opts)
if err != nil {
return nil, nil, err
}
@@ -741,14 +787,18 @@ func (m *oauthManager) login(ctx context.Context, apiBase string, opts OAuthLogi
}
func (m *oauthManager) completeLogin(ctx context.Context, apiBase, pkceVerifier string, callback *oauthCallbackResult, state string, opts OAuthLoginOptions) (*oauthSession, []string, error) {
session, err := m.exchangeCode(ctx, callback.Code, pkceVerifier, state)
session, err := m.exchangeCode(ctx, callback.Code, pkceVerifier, state, opts.NetworkProxy)
if err != nil {
return nil, nil, err
}
if err := m.applyAccountLabel(session, opts); err != nil {
if err := m.applySessionOptions(session, opts); err != nil {
return nil, nil, err
}
models, _ := fetchOpenAIModels(ctx, m.httpClient, apiBase, session.AccessToken)
client, err := m.httpClientForSession(session)
if err != nil {
return nil, nil, err
}
models, _ := fetchOpenAIModels(ctx, client, apiBase, session.AccessToken)
if len(models) > 0 {
session.Models = append([]string(nil), models...)
}
@@ -788,10 +838,14 @@ func (m *oauthManager) importSession(ctx context.Context, apiBase, fileName stri
if err != nil {
return nil, nil, err
}
if err := m.applyAccountLabel(session, opts); err != nil {
if err := m.applySessionOptions(session, opts); err != nil {
return nil, nil, err
}
models, _ := fetchOpenAIModels(ctx, m.httpClient, apiBase, session.AccessToken)
client, err := m.httpClientForSession(session)
if err != nil {
return nil, nil, err
}
models, _ := fetchOpenAIModels(ctx, client, apiBase, session.AccessToken)
if len(models) > 0 {
session.Models = append([]string(nil), models...)
}
@@ -1094,7 +1148,7 @@ func (m *oauthManager) authorizationURL(state, pkceChallenge string) string {
return m.cfg.AuthURL + "?" + v.Encode()
}
func (m *oauthManager) exchangeCode(ctx context.Context, code, verifier, state string) (*oauthSession, error) {
func (m *oauthManager) exchangeCode(ctx context.Context, code, verifier, state string, proxyURL string) (*oauthSession, error) {
switch m.cfg.Provider {
case defaultClaudeOAuthProvider:
reqBody := map[string]any{
@@ -1105,7 +1159,7 @@ func (m *oauthManager) exchangeCode(ctx context.Context, code, verifier, state s
"redirect_uri": m.cfg.RedirectURL,
"code_verifier": verifier,
}
raw, err := m.doJSONTokenRequest(ctx, reqBody)
raw, err := m.doJSONTokenRequest(ctx, reqBody, proxyURL)
if err != nil {
return nil, err
}
@@ -1122,7 +1176,7 @@ func (m *oauthManager) exchangeCode(ctx context.Context, code, verifier, state s
if m.cfg.ClientSecret != "" {
form.Set("client_secret", m.cfg.ClientSecret)
}
raw, err := m.doFormTokenRequest(ctx, form)
raw, err := m.doFormTokenRequest(ctx, form, proxyURL)
if err != nil {
return nil, err
}
@@ -1159,7 +1213,7 @@ func (m *oauthManager) refreshSessionData(ctx context.Context, session *oauthSes
"client_id": m.cfg.ClientID,
"grant_type": "refresh_token",
"refresh_token": session.RefreshToken,
})
}, session.NetworkProxy)
if err != nil {
return nil, err
}
@@ -1185,7 +1239,7 @@ func (m *oauthManager) refreshSessionData(ctx context.Context, session *oauthSes
if len(m.cfg.Scopes) > 0 && m.cfg.Provider != defaultQwenOAuthProvider && m.cfg.Provider != defaultKimiOAuthProvider {
form.Set("scope", strings.Join(m.cfg.Scopes, " "))
}
raw, err := m.doFormTokenRequest(ctx, form)
raw, err := m.doFormTokenRequest(ctx, form, session.NetworkProxy)
if err != nil {
return nil, err
}
@@ -1217,7 +1271,7 @@ func (m *oauthManager) refreshGoogleTokenSession(ctx context.Context, session *o
if clientSecret != "" {
form.Set("client_secret", clientSecret)
}
raw, err := m.doFormTokenRequestURL(ctx, tokenURL, form)
raw, err := m.doFormTokenRequestURL(ctx, tokenURL, form, session.NetworkProxy)
if err != nil {
return nil, err
}
@@ -1254,13 +1308,13 @@ func (m *oauthManager) enrichSession(ctx context.Context, session *oauthSession)
switch m.cfg.Provider {
case defaultAntigravityOAuthProvider, defaultGeminiOAuthProvider:
if strings.TrimSpace(session.Email) == "" && m.cfg.UserInfoURL != "" && session.AccessToken != "" {
email, err := m.fetchUserEmail(ctx, session.AccessToken)
email, err := m.fetchUserEmail(ctx, session.AccessToken, session.NetworkProxy)
if err == nil {
session.Email = email
}
}
if m.cfg.Provider == defaultAntigravityOAuthProvider && strings.TrimSpace(session.ProjectID) == "" && session.AccessToken != "" {
projectID, err := m.fetchAntigravityProjectID(ctx, session.AccessToken)
projectID, err := m.fetchAntigravityProjectID(ctx, session.AccessToken, session.NetworkProxy)
if err == nil {
session.ProjectID = projectID
}
@@ -1413,6 +1467,22 @@ func (m *oauthManager) applyAccountLabel(session *oauthSession, opts OAuthLoginO
return nil
}
func (m *oauthManager) applySessionOptions(session *oauthSession, opts OAuthLoginOptions) error {
if err := m.applyAccountLabel(session, opts); err != nil {
return err
}
if session == nil {
return fmt.Errorf("oauth session is nil")
}
proxyURL := firstNonEmpty(opts.NetworkProxy, session.NetworkProxy)
normalized, err := normalizeOptionalProxyURL(proxyURL)
if err != nil {
return err
}
session.NetworkProxy = normalized
return nil
}
func sessionLabel(session *oauthSession) string {
if session == nil {
return ""
@@ -1420,6 +1490,34 @@ func sessionLabel(session *oauthSession) string {
return firstNonEmpty(session.Email, session.AccountID, session.ProjectID)
}
func (m *oauthManager) httpClientForSession(session *oauthSession) (*http.Client, error) {
if session == nil {
return m.httpClient, nil
}
return m.httpClientForProxy(session.NetworkProxy)
}
func (m *oauthManager) httpClientForProxy(proxyURL string) (*http.Client, error) {
normalized, err := normalizeOptionalProxyURL(proxyURL)
if err != nil {
return nil, err
}
if normalized == "" {
return m.httpClient, nil
}
m.clientMu.Lock()
defer m.clientMu.Unlock()
if client, ok := m.clients[normalized]; ok && client != nil {
return client, nil
}
client, err := newOAuthHTTPClient(m.cfg.Provider, m.timeout, normalized)
if err != nil {
return nil, err
}
m.clients[normalized] = client
return client, nil
}
func (m *oauthManager) allocateCredentialPathLocked(session *oauthSession) (string, error) {
files := m.credentialFiles()
if len(files) == 1 {
@@ -1517,6 +1615,7 @@ func mergeOAuthSession(prev, next *oauthSession) *oauthSession {
merged.ProjectID = firstNonEmpty(next.ProjectID, prev.ProjectID)
merged.DeviceID = firstNonEmpty(next.DeviceID, prev.DeviceID)
merged.ResourceURL = firstNonEmpty(next.ResourceURL, prev.ResourceURL)
merged.NetworkProxy = firstNonEmpty(next.NetworkProxy, prev.NetworkProxy)
merged.Scope = firstNonEmpty(next.Scope, prev.Scope)
merged.Models = append([]string(nil), prev.Models...)
if len(next.Models) > 0 {
@@ -1828,17 +1927,6 @@ func durationFromSeconds(value int, fallback time.Duration) time.Duration {
return time.Duration(value) * time.Second
}
func normalizeHybridPriority(value string) string {
switch strings.ToLower(strings.TrimSpace(value)) {
case "", "api_first":
return "api_first"
case "oauth_first":
return "oauth_first"
default:
return ""
}
}
func parseOAuthCallbackURL(raw string) (*oauthCallbackResult, error) {
parsed, err := url.Parse(strings.TrimSpace(raw))
if err != nil {
@@ -1902,6 +1990,7 @@ func parseImportedOAuthSession(provider, fileName string, data []byte) (*oauthSe
session.ProjectID = firstNonEmpty(asString(raw["project_id"]), asString(raw["projectId"]))
session.DeviceID = firstNonEmpty(asString(raw["device_id"]), asString(raw["deviceId"]))
session.ResourceURL = asString(raw["resource_url"])
session.NetworkProxy = firstNonEmpty(asString(raw["network_proxy"]), asString(raw["proxy_url"]), asString(raw["http_proxy"]))
session.Scope = firstNonEmpty(asString(raw["scope"]), asString(raw["scopes"]))
if models := stringSliceFromAny(raw["models"]); len(models) > 0 {
session.Models = models
@@ -1916,6 +2005,7 @@ func parseImportedOAuthSession(provider, fileName string, data []byte) (*oauthSe
)
session.ProjectID = firstNonEmpty(session.ProjectID, asString(session.Token["project_id"]), asString(session.Token["projectId"]))
session.DeviceID = firstNonEmpty(session.DeviceID, asString(session.Token["device_id"]), asString(session.Token["deviceId"]))
session.NetworkProxy = firstNonEmpty(session.NetworkProxy, asString(session.Token["network_proxy"]), asString(session.Token["proxy_url"]), asString(session.Token["http_proxy"]))
session.Scope = firstNonEmpty(session.Scope, asString(session.Token["scope"]), asString(session.Token["scopes"]))
}
if claims := parseJWTClaims(session.IDToken); len(claims) > 0 {
@@ -2043,21 +2133,21 @@ func defaultInt(value, fallback int) int {
return fallback
}
func (m *oauthManager) doFormTokenRequest(ctx context.Context, form url.Values) (map[string]any, error) {
return m.doFormTokenRequestURL(ctx, m.cfg.TokenURL, form)
func (m *oauthManager) doFormTokenRequest(ctx context.Context, form url.Values, proxyURL string) (map[string]any, error) {
return m.doFormTokenRequestURL(ctx, m.cfg.TokenURL, form, proxyURL)
}
func (m *oauthManager) doFormTokenRequestURL(ctx context.Context, endpoint string, form url.Values) (map[string]any, error) {
func (m *oauthManager) doFormTokenRequestURL(ctx context.Context, endpoint string, form url.Values, proxyURL string) (map[string]any, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, strings.NewReader(form.Encode()))
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.Header.Set("Accept", "application/json")
return m.doJSONRequest(req, "oauth token request")
return m.doJSONRequest(req, "oauth token request", proxyURL)
}
func (m *oauthManager) doJSONTokenRequest(ctx context.Context, payload map[string]any) (map[string]any, error) {
func (m *oauthManager) doJSONTokenRequest(ctx context.Context, payload map[string]any, proxyURL string) (map[string]any, error) {
body, err := json.Marshal(payload)
if err != nil {
return nil, err
@@ -2068,11 +2158,15 @@ func (m *oauthManager) doJSONTokenRequest(ctx context.Context, payload map[strin
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept", "application/json")
return m.doJSONRequest(req, "oauth token request")
return m.doJSONRequest(req, "oauth token request", proxyURL)
}
func (m *oauthManager) doJSONRequest(req *http.Request, label string) (map[string]any, error) {
resp, err := m.httpClient.Do(req)
func (m *oauthManager) doJSONRequest(req *http.Request, label, proxyURL string) (map[string]any, error) {
client, err := m.httpClientForProxy(proxyURL)
if err != nil {
return nil, err
}
resp, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("%s failed: %w", label, err)
}
@@ -2091,13 +2185,13 @@ func (m *oauthManager) doJSONRequest(req *http.Request, label string) (map[strin
return raw, nil
}
func (m *oauthManager) fetchUserEmail(ctx context.Context, token string) (string, error) {
func (m *oauthManager) fetchUserEmail(ctx context.Context, token, proxyURL string) (string, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, m.cfg.UserInfoURL, nil)
if err != nil {
return "", err
}
req.Header.Set("Authorization", "Bearer "+strings.TrimSpace(token))
raw, err := m.doJSONRequest(req, "oauth userinfo request")
raw, err := m.doJSONRequest(req, "oauth userinfo request", proxyURL)
if err != nil {
return "", err
}
@@ -2108,7 +2202,7 @@ func (m *oauthManager) fetchUserEmail(ctx context.Context, token string) (string
return email, nil
}
func (m *oauthManager) fetchAntigravityProjectID(ctx context.Context, token string) (string, error) {
func (m *oauthManager) fetchAntigravityProjectID(ctx context.Context, token, proxyURL string) (string, error) {
endpointURL := fmt.Sprintf("%s/%s:loadCodeAssist", defaultAntigravityAPIEndpoint, defaultAntigravityAPIVersion)
body := `{"metadata":{"ideType":"ANTIGRAVITY","platform":"PLATFORM_UNSPECIFIED","pluginType":"GEMINI"}}`
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpointURL, strings.NewReader(body))
@@ -2120,7 +2214,11 @@ func (m *oauthManager) fetchAntigravityProjectID(ctx context.Context, token stri
req.Header.Set("User-Agent", defaultAntigravityAPIUserAgent)
req.Header.Set("X-Goog-Api-Client", defaultAntigravityAPIClient)
req.Header.Set("Client-Metadata", defaultAntigravityClientMeta)
resp, err := m.httpClient.Do(req)
client, err := m.httpClientForProxy(proxyURL)
if err != nil {
return "", err
}
resp, err := client.Do(req)
if err != nil {
return "", err
}
@@ -2148,7 +2246,7 @@ func (m *oauthManager) fetchAntigravityProjectID(ctx context.Context, token stri
return projectID, nil
}
func (m *oauthManager) startDeviceFlow(ctx context.Context) (*OAuthPendingFlow, error) {
func (m *oauthManager) startDeviceFlow(ctx context.Context, opts OAuthLoginOptions) (*OAuthPendingFlow, error) {
if m.cfg.FlowKind != oauthFlowDevice {
return nil, fmt.Errorf("oauth provider %s does not use device flow", m.cfg.Provider)
}
@@ -2165,7 +2263,7 @@ func (m *oauthManager) startDeviceFlow(ctx context.Context) (*OAuthPendingFlow,
}
form.Set("code_challenge", challenge)
form.Set("code_challenge_method", "S256")
raw, err := m.doFormDeviceRequest(ctx, m.cfg.DeviceCodeURL, form)
raw, err := m.doFormDeviceRequest(ctx, m.cfg.DeviceCodeURL, form, opts.NetworkProxy)
if err != nil {
return nil, err
}
@@ -2184,7 +2282,7 @@ func (m *oauthManager) startDeviceFlow(ctx context.Context) (*OAuthPendingFlow,
Instructions: "Open the verification URL, finish authorization, then click continue to let the gateway poll for tokens.",
}, nil
case defaultKimiOAuthProvider:
raw, err := m.doFormDeviceRequest(ctx, m.cfg.DeviceCodeURL, form)
raw, err := m.doFormDeviceRequest(ctx, m.cfg.DeviceCodeURL, form, opts.NetworkProxy)
if err != nil {
return nil, err
}
@@ -2206,7 +2304,7 @@ func (m *oauthManager) startDeviceFlow(ctx context.Context) (*OAuthPendingFlow,
}
}
func (m *oauthManager) doFormDeviceRequest(ctx context.Context, endpoint string, form url.Values) (map[string]any, error) {
func (m *oauthManager) doFormDeviceRequest(ctx context.Context, endpoint string, form url.Values, proxyURL string) (map[string]any, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, strings.NewReader(form.Encode()))
if err != nil {
return nil, err
@@ -2220,7 +2318,7 @@ func (m *oauthManager) doFormDeviceRequest(ctx context.Context, endpoint string,
req.Header.Set("X-Msh-Device-Model", runtime.GOOS+" "+runtime.GOARCH)
req.Header.Set("X-Msh-Device-Id", randomDeviceID())
}
return m.doJSONRequest(req, "oauth device request")
return m.doJSONRequest(req, "oauth device request", proxyURL)
}
func parseDeviceFlowPayload(raw map[string]any) (*oauthDeviceCodeResponse, error) {
@@ -2254,14 +2352,18 @@ func randomDeviceID() string {
}
func (m *oauthManager) completeDeviceFlow(ctx context.Context, apiBase string, flow *OAuthPendingFlow, opts OAuthLoginOptions) (*oauthSession, []string, error) {
session, err := m.pollDeviceToken(ctx, flow)
session, err := m.pollDeviceToken(ctx, flow, opts.NetworkProxy)
if err != nil {
return nil, nil, err
}
if err := m.applyAccountLabel(session, opts); err != nil {
if err := m.applySessionOptions(session, opts); err != nil {
return nil, nil, err
}
models, _ := fetchOpenAIModels(ctx, m.httpClient, apiBase, session.AccessToken)
client, err := m.httpClientForSession(session)
if err != nil {
return nil, nil, err
}
models, _ := fetchOpenAIModels(ctx, client, apiBase, session.AccessToken)
if len(models) > 0 {
session.Models = append([]string(nil), models...)
}
@@ -2281,7 +2383,7 @@ func (m *oauthManager) completeDeviceFlow(ctx context.Context, apiBase string, f
return session, models, nil
}
func (m *oauthManager) pollDeviceToken(ctx context.Context, flow *OAuthPendingFlow) (*oauthSession, error) {
func (m *oauthManager) pollDeviceToken(ctx context.Context, flow *OAuthPendingFlow, proxyURL string) (*oauthSession, error) {
if flow == nil || strings.TrimSpace(flow.DeviceCode) == "" {
return nil, fmt.Errorf("oauth device flow missing device code")
}
@@ -2306,7 +2408,7 @@ func (m *oauthManager) pollDeviceToken(ctx context.Context, flow *OAuthPendingFl
if flow.PKCEVerifier != "" {
form.Set("code_verifier", flow.PKCEVerifier)
}
raw, err := m.doFormTokenRequest(ctx, form)
raw, err := m.doFormTokenRequest(ctx, form, proxyURL)
if err == nil {
session, convErr := sessionFromTokenPayload(m.cfg.Provider, raw)
if convErr != nil {

View File

@@ -4,6 +4,7 @@ import (
"bytes"
"context"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"os"
@@ -386,6 +387,118 @@ func TestResolveOAuthConfigAppliesProviderRefreshLeadDefaults(t *testing.T) {
}
}
func TestHTTPProviderOAuthSessionProxyRoutesRefreshAndResponses(t *testing.T) {
t.Parallel()
var refreshCalls int32
var responseCalls int32
target := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/oauth/token":
atomic.AddInt32(&refreshCalls, 1)
if err := r.ParseForm(); err != nil {
t.Fatalf("parse token form failed: %v", err)
}
if got := r.Form.Get("grant_type"); got != "refresh_token" {
t.Fatalf("unexpected grant_type: %s", got)
}
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"access_token":"proxied-fresh-token","refresh_token":"refresh-token","expires_in":3600}`))
case "/v1/responses":
atomic.AddInt32(&responseCalls, 1)
if got := r.Header.Get("Authorization"); got != "Bearer proxied-fresh-token" {
t.Fatalf("unexpected authorization header: %s", got)
}
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"status":"completed","output_text":"ok-via-proxy"}`))
default:
http.NotFound(w, r)
}
}))
defer target.Close()
var proxyCalls int32
proxyServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
atomic.AddInt32(&proxyCalls, 1)
targetURL := r.URL.String()
if !strings.HasPrefix(targetURL, "http://") && !strings.HasPrefix(targetURL, "https://") {
targetURL = target.URL + r.URL.Path
if rawQuery := strings.TrimSpace(r.URL.RawQuery); rawQuery != "" {
targetURL += "?" + rawQuery
}
}
req, err := http.NewRequestWithContext(r.Context(), r.Method, targetURL, r.Body)
if err != nil {
t.Fatalf("create proxied request failed: %v", err)
}
req.Header = r.Header.Clone()
resp, err := http.DefaultTransport.RoundTrip(req)
if err != nil {
t.Fatalf("proxy round trip failed: %v", err)
}
defer resp.Body.Close()
for key, values := range resp.Header {
for _, value := range values {
w.Header().Add(key, value)
}
}
w.WriteHeader(resp.StatusCode)
_, _ = io.Copy(w, resp.Body)
}))
defer proxyServer.Close()
credFile := filepath.Join(t.TempDir(), "proxied.json")
raw, err := json.Marshal(oauthSession{
Provider: "codex",
AccessToken: "expired-token",
RefreshToken: "refresh-token",
Expire: time.Now().Add(-time.Hour).Format(time.RFC3339),
NetworkProxy: proxyServer.URL,
})
if err != nil {
t.Fatalf("marshal session failed: %v", err)
}
if err := os.WriteFile(credFile, raw, 0o600); err != nil {
t.Fatalf("write credential file failed: %v", err)
}
pc := config.ProviderConfig{
APIBase: target.URL + "/v1",
Auth: "oauth",
TimeoutSec: 5,
OAuth: config.ProviderOAuthConfig{
Provider: "codex",
CredentialFile: credFile,
ClientID: "test-client",
TokenURL: target.URL + "/oauth/token",
AuthURL: target.URL + "/oauth/authorize",
},
}
oauth, err := newOAuthManager(pc, 5*time.Second)
if err != nil {
t.Fatalf("new oauth manager failed: %v", err)
}
defer oauth.bgCancel()
provider := NewHTTPProvider("test-oauth", "", pc.APIBase, "gpt-test", false, "oauth", 5*time.Second, oauth)
resp, err := provider.Chat(context.Background(), []Message{{Role: "user", Content: "hello"}}, nil, "gpt-test", nil)
if err != nil {
t.Fatalf("chat failed: %v", err)
}
if resp.Content != "ok-via-proxy" {
t.Fatalf("unexpected response content: %q", resp.Content)
}
if atomic.LoadInt32(&refreshCalls) != 1 {
t.Fatalf("expected one refresh call, got %d", refreshCalls)
}
if atomic.LoadInt32(&responseCalls) != 1 {
t.Fatalf("expected one response call, got %d", responseCalls)
}
if got := atomic.LoadInt32(&proxyCalls); got < 2 {
t.Fatalf("expected proxy to receive refresh and response requests, got %d", got)
}
}
func TestOAuthImportGeminiNestedTokenRefreshesWithTokenMetadata(t *testing.T) {
t.Parallel()
@@ -547,7 +660,7 @@ func TestQwenDeviceFlowRequiresAccountLabelWhenEmailMissing(t *testing.T) {
}
defer manager.bgCancel()
flow, err := manager.startDeviceFlow(context.Background())
flow, err := manager.startDeviceFlow(context.Background(), OAuthLoginOptions{})
if err != nil {
t.Fatalf("start device flow failed: %v", err)
}
@@ -734,7 +847,7 @@ func TestOAuthDeviceFlowQwenManualCompletes(t *testing.T) {
t.Fatalf("new oauth manager failed: %v", err)
}
flow, err := manager.startDeviceFlow(context.Background())
flow, err := manager.startDeviceFlow(context.Background(), OAuthLoginOptions{})
if err != nil {
t.Fatalf("start device flow failed: %v", err)
}
@@ -879,7 +992,6 @@ func TestHTTPProviderHybridOAuthFirstUsesOAuthBeforeAPIKey(t *testing.T) {
CredentialFile: credFile,
TokenURL: server.URL + "/oauth/token",
AuthURL: server.URL + "/oauth/authorize",
HybridPriority: "oauth_first",
},
}
oauth, err := newOAuthManager(pc, 5*time.Second)
@@ -891,11 +1003,11 @@ func TestHTTPProviderHybridOAuthFirstUsesOAuthBeforeAPIKey(t *testing.T) {
if err != nil {
t.Fatalf("chat failed: %v", err)
}
if resp.Content != "ok-from-oauth" {
if resp.Content != "ok-from-api" {
t.Fatalf("unexpected response content: %q", resp.Content)
}
if atomic.LoadInt32(&oauthCalls) != 1 || atomic.LoadInt32(&apiKeyCalls) != 0 {
t.Fatalf("expected oauth first only, got api=%d oauth=%d", apiKeyCalls, oauthCalls)
if atomic.LoadInt32(&apiKeyCalls) != 1 || atomic.LoadInt32(&oauthCalls) != 0 {
t.Fatalf("expected api key first only, got api=%d oauth=%d", apiKeyCalls, oauthCalls)
}
}
@@ -1187,7 +1299,6 @@ func TestProviderRuntimeSnapshotIncludesCandidateOrderAndLastSuccess(t *testing.
OAuth: config.ProviderOAuthConfig{
Provider: "codex",
CredentialFile: credFile,
HybridPriority: "api_first",
},
}
ConfigureProviderRuntime(name, pc)
@@ -1206,8 +1317,8 @@ func TestProviderRuntimeSnapshotIncludesCandidateOrderAndLastSuccess(t *testing.
provider.markAttemptSuccess(attempts[1])
cfg := &config.Config{
Providers: config.ProvidersConfig{
Proxies: map[string]config.ProviderConfig{name: pc},
Models: config.ModelsConfig{
Providers: map[string]config.ProviderConfig{name: pc},
},
}
snapshot := GetProviderRuntimeSnapshot(cfg)
@@ -1267,8 +1378,8 @@ func TestConfigureProviderRuntimeLoadsPersistedEvents(t *testing.T) {
})
cfg := &config.Config{
Providers: config.ProvidersConfig{
Proxies: map[string]config.ProviderConfig{
Models: config.ModelsConfig{
Providers: map[string]config.ProviderConfig{
name: {
APIBase: "https://example.com/v1",
Auth: "bearer",
@@ -1357,7 +1468,6 @@ func TestUpdateCandidateOrderRecordsSchedulerChange(t *testing.T) {
OAuth: config.ProviderOAuthConfig{
Provider: "codex",
CredentialFile: credFile,
HybridPriority: "api_first",
},
}
manager, err := newOAuthManager(pc, 5*time.Second)
@@ -1412,8 +1522,8 @@ func TestGetProviderRuntimeViewFiltersEvents(t *testing.T) {
providerRuntimeRegistry.mu.Unlock()
cfg := &config.Config{
Providers: config.ProvidersConfig{
Proxies: map[string]config.ProviderConfig{
Models: config.ModelsConfig{
Providers: map[string]config.ProviderConfig{
name: {APIBase: "https://example.com/v1", Auth: "hybrid", APIKey: "api-key"},
},
},
@@ -1463,8 +1573,8 @@ func TestGetProviderRuntimeViewCursorPagination(t *testing.T) {
}
providerRuntimeRegistry.mu.Unlock()
cfg := &config.Config{
Providers: config.ProvidersConfig{
Proxies: map[string]config.ProviderConfig{
Models: config.ModelsConfig{
Providers: map[string]config.ProviderConfig{
name: {APIBase: "https://example.com/v1", Auth: "oauth", OAuth: config.ProviderOAuthConfig{Provider: "codex"}},
},
},
@@ -1503,8 +1613,8 @@ func TestGetProviderRuntimeViewSortAscending(t *testing.T) {
}
providerRuntimeRegistry.mu.Unlock()
cfg := &config.Config{
Providers: config.ProvidersConfig{
Proxies: map[string]config.ProviderConfig{
Models: config.ModelsConfig{
Providers: map[string]config.ProviderConfig{
name: {APIBase: "https://example.com/v1", Auth: "oauth", OAuth: config.ProviderOAuthConfig{Provider: "codex"}},
},
},
@@ -1542,8 +1652,8 @@ func TestGetProviderRuntimeViewFiltersByHealthAndCooldown(t *testing.T) {
}
providerRuntimeRegistry.mu.Unlock()
cfg := &config.Config{
Providers: config.ProvidersConfig{
Proxies: map[string]config.ProviderConfig{
Models: config.ModelsConfig{
Providers: map[string]config.ProviderConfig{
name: {APIBase: "https://example.com/v1", Auth: "bearer", APIKey: "api-key"},
},
},
@@ -1594,8 +1704,8 @@ func TestGetProviderRuntimeSummaryFlagsUnhealthyProviders(t *testing.T) {
}
providerRuntimeRegistry.mu.Unlock()
cfg := &config.Config{
Providers: config.ProvidersConfig{
Proxies: map[string]config.ProviderConfig{
Models: config.ModelsConfig{
Providers: map[string]config.ProviderConfig{
name: {APIBase: "https://example.com/v1", Auth: "bearer", APIKey: "api-key"},
},
},
@@ -1643,8 +1753,8 @@ func TestGetProviderRuntimeSummaryMarksRecentErrorsAsDegraded(t *testing.T) {
}
providerRuntimeRegistry.mu.Unlock()
cfg := &config.Config{
Providers: config.ProvidersConfig{
Proxies: map[string]config.ProviderConfig{
Models: config.ModelsConfig{
Providers: map[string]config.ProviderConfig{
name: {APIBase: "https://example.com/v1", Auth: "oauth", OAuth: config.ProviderOAuthConfig{Provider: "codex"}},
},
},
@@ -1688,8 +1798,8 @@ func TestGetProviderRuntimeSummaryIncludesOAuthAccountMetadata(t *testing.T) {
t.Fatalf("write session failed: %v", err)
}
cfg := &config.Config{
Providers: config.ProvidersConfig{
Proxies: map[string]config.ProviderConfig{
Models: config.ModelsConfig{
Providers: map[string]config.ProviderConfig{
"qwen-summary": {
APIBase: "https://example.com/v1",
Auth: "oauth",
@@ -1746,8 +1856,8 @@ func TestRefreshProviderRuntimeNowSupportsOnlyExpiring(t *testing.T) {
name := "runtime-refresh-provider"
cfg := &config.Config{
Providers: config.ProvidersConfig{
Proxies: map[string]config.ProviderConfig{
Models: config.ModelsConfig{
Providers: map[string]config.ProviderConfig{
name: {
APIBase: server.URL + "/v1",
Auth: "oauth",
@@ -1807,8 +1917,8 @@ func TestRerankProviderRuntimeUpdatesCandidateOrder(t *testing.T) {
}
name := "rerank-runtime-provider"
cfg := &config.Config{
Providers: config.ProvidersConfig{
Proxies: map[string]config.ProviderConfig{
Models: config.ModelsConfig{
Providers: map[string]config.ProviderConfig{
name: {
APIKey: "api-key",
APIBase: "https://example.com/v1",
@@ -1817,7 +1927,6 @@ func TestRerankProviderRuntimeUpdatesCandidateOrder(t *testing.T) {
OAuth: config.ProviderOAuthConfig{
Provider: "codex",
CredentialFile: credFile,
HybridPriority: "oauth_first",
},
},
},
@@ -1827,8 +1936,8 @@ func TestRerankProviderRuntimeUpdatesCandidateOrder(t *testing.T) {
if err != nil {
t.Fatalf("rerank provider runtime failed: %v", err)
}
if len(order) == 0 || order[0].Kind != "oauth" {
t.Fatalf("expected oauth-first rerank result, got %#v", order)
if len(order) == 0 || order[0].Kind != "api_key" {
t.Fatalf("expected api-key-first rerank result, got %#v", order)
}
snapshot := GetProviderRuntimeSnapshot(cfg)
items, _ := snapshot["items"].([]map[string]interface{})
@@ -1836,7 +1945,7 @@ func TestRerankProviderRuntimeUpdatesCandidateOrder(t *testing.T) {
t.Fatalf("expected one runtime item, got %#v", snapshot)
}
snapshotOrder, _ := items[0]["candidate_order"].([]providerRuntimeCandidate)
if len(snapshotOrder) == 0 || snapshotOrder[0].Kind != "oauth" {
t.Fatalf("expected oauth-first candidate order, got %#v", items[0]["candidate_order"])
if len(snapshotOrder) == 0 || snapshotOrder[0].Kind != "api_key" {
t.Fatalf("expected api-key-first candidate order, got %#v", items[0]["candidate_order"])
}
}