mirror of
https://github.com/YspCoder/clawgo.git
synced 2026-06-23 18:20:34 +08:00
release: v0.2.0
This commit is contained in:
@@ -342,22 +342,20 @@ func NewAgentLoop(cfg *config.Config, msgBus *bus.MessageBus, provider providers
|
||||
subagentDigestDelay: 5 * time.Second,
|
||||
subagentDigests: map[string]*subagentDigestState{},
|
||||
}
|
||||
if _, primaryModel := config.ParseProviderModelRef(cfg.Agents.Defaults.Model.Primary); strings.TrimSpace(primaryModel) != "" {
|
||||
loop.model = strings.TrimSpace(primaryModel)
|
||||
}
|
||||
go loop.runSubagentDigestTicker()
|
||||
// Initialize provider fallback chain (primary + proxy_fallbacks).
|
||||
// Initialize provider fallback chain (primary + model fallbacks).
|
||||
loop.providerPool = map[string]providers.LLMProvider{}
|
||||
loop.providerNames = []string{}
|
||||
primaryName := cfg.Agents.Defaults.Proxy
|
||||
if primaryName == "" {
|
||||
primaryName = "proxy"
|
||||
}
|
||||
primaryName := config.PrimaryProviderName(cfg)
|
||||
loop.providerPool[primaryName] = provider
|
||||
loop.providerNames = append(loop.providerNames, primaryName)
|
||||
if strings.TrimSpace(primaryName) == "proxy" {
|
||||
loop.providerResponses[primaryName] = cfg.Providers.Proxy.Responses
|
||||
} else if pc, ok := cfg.Providers.Proxies[primaryName]; ok {
|
||||
if pc, ok := config.ProviderConfigByName(cfg, primaryName); ok {
|
||||
loop.providerResponses[primaryName] = pc.Responses
|
||||
}
|
||||
for _, name := range cfg.Agents.Defaults.ProxyFallbacks {
|
||||
for _, name := range cfg.Agents.Defaults.Model.Fallbacks {
|
||||
if name == "" {
|
||||
continue
|
||||
}
|
||||
@@ -371,13 +369,13 @@ func NewAgentLoop(cfg *config.Config, msgBus *bus.MessageBus, provider providers
|
||||
if dup {
|
||||
continue
|
||||
}
|
||||
if p2, err := providers.CreateProviderByName(cfg, name); err == nil {
|
||||
loop.providerPool[name] = p2
|
||||
loop.providerNames = append(loop.providerNames, name)
|
||||
if pc, ok := cfg.Providers.Proxies[name]; ok {
|
||||
loop.providerResponses[name] = pc.Responses
|
||||
if p2, err := providers.CreateProviderByName(cfg, name); err == nil {
|
||||
loop.providerPool[name] = p2
|
||||
loop.providerNames = append(loop.providerNames, name)
|
||||
if pc, ok := config.ProviderConfigByName(cfg, name); ok {
|
||||
loop.providerResponses[name] = pc.Responses
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Inject recursive run logic so subagents can use full tool-calling flows.
|
||||
|
||||
@@ -914,8 +914,8 @@ func collectRiskyConfigPaths(oldMap, newMap map[string]interface{}) []string {
|
||||
"channels.telegram.token",
|
||||
"channels.telegram.allow_from",
|
||||
"channels.telegram.allow_chats",
|
||||
"providers.proxy.api_base",
|
||||
"providers.proxy.api_key",
|
||||
"models.providers.openai.api_base",
|
||||
"models.providers.openai.api_key",
|
||||
"gateway.token",
|
||||
"gateway.port",
|
||||
}
|
||||
@@ -923,9 +923,9 @@ func collectRiskyConfigPaths(oldMap, newMap map[string]interface{}) []string {
|
||||
for _, path := range paths {
|
||||
seen[path] = true
|
||||
}
|
||||
for _, name := range collectProviderProxyNames(oldMap, newMap) {
|
||||
for _, name := range collectProviderNames(oldMap, newMap) {
|
||||
for _, field := range []string{"api_base", "api_key"} {
|
||||
path := "providers.proxies." + name + "." + field
|
||||
path := "models.providers." + name + "." + field
|
||||
if !seen[path] {
|
||||
paths = append(paths, path)
|
||||
seen[path] = true
|
||||
@@ -935,13 +935,13 @@ func collectRiskyConfigPaths(oldMap, newMap map[string]interface{}) []string {
|
||||
return paths
|
||||
}
|
||||
|
||||
func collectProviderProxyNames(maps ...map[string]interface{}) []string {
|
||||
func collectProviderNames(maps ...map[string]interface{}) []string {
|
||||
seen := map[string]bool{}
|
||||
names := make([]string, 0)
|
||||
for _, root := range maps {
|
||||
providers, _ := root["providers"].(map[string]interface{})
|
||||
proxies, _ := providers["proxies"].(map[string]interface{})
|
||||
for name := range proxies {
|
||||
models, _ := root["models"].(map[string]interface{})
|
||||
providers, _ := models["providers"].(map[string]interface{})
|
||||
for name := range providers {
|
||||
if strings.TrimSpace(name) == "" || seen[name] {
|
||||
continue
|
||||
}
|
||||
@@ -1001,6 +1001,7 @@ func (s *Server) handleWebUIProviderOAuthStart(w http.ResponseWriter, r *http.Re
|
||||
var body struct {
|
||||
Provider string `json:"provider"`
|
||||
AccountLabel string `json:"account_label"`
|
||||
NetworkProxy string `json:"network_proxy"`
|
||||
ProviderConfig cfgpkg.ProviderConfig `json:"provider_config"`
|
||||
}
|
||||
if r.Method == http.MethodPost {
|
||||
@@ -1011,6 +1012,7 @@ func (s *Server) handleWebUIProviderOAuthStart(w http.ResponseWriter, r *http.Re
|
||||
} else {
|
||||
body.Provider = strings.TrimSpace(r.URL.Query().Get("provider"))
|
||||
body.AccountLabel = strings.TrimSpace(r.URL.Query().Get("account_label"))
|
||||
body.NetworkProxy = strings.TrimSpace(r.URL.Query().Get("network_proxy"))
|
||||
}
|
||||
cfg, pc, err := s.resolveProviderConfig(strings.TrimSpace(body.Provider), body.ProviderConfig)
|
||||
if err != nil {
|
||||
@@ -1027,7 +1029,10 @@ func (s *Server) handleWebUIProviderOAuthStart(w http.ResponseWriter, r *http.Re
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
flow, err := loginMgr.StartManualFlow()
|
||||
flow, err := loginMgr.StartManualFlowWithOptions(providers.OAuthLoginOptions{
|
||||
AccountLabel: body.AccountLabel,
|
||||
NetworkProxy: firstNonEmptyString(strings.TrimSpace(body.NetworkProxy), strings.TrimSpace(pc.OAuth.NetworkProxy)),
|
||||
})
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
@@ -1044,6 +1049,7 @@ func (s *Server) handleWebUIProviderOAuthStart(w http.ResponseWriter, r *http.Re
|
||||
"user_code": flow.UserCode,
|
||||
"instructions": flow.Instructions,
|
||||
"account_label": strings.TrimSpace(body.AccountLabel),
|
||||
"network_proxy": strings.TrimSpace(body.NetworkProxy),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1061,6 +1067,7 @@ func (s *Server) handleWebUIProviderOAuthComplete(w http.ResponseWriter, r *http
|
||||
FlowID string `json:"flow_id"`
|
||||
CallbackURL string `json:"callback_url"`
|
||||
AccountLabel string `json:"account_label"`
|
||||
NetworkProxy string `json:"network_proxy"`
|
||||
ProviderConfig cfgpkg.ProviderConfig `json:"provider_config"`
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
|
||||
@@ -1091,6 +1098,7 @@ func (s *Server) handleWebUIProviderOAuthComplete(w http.ResponseWriter, r *http
|
||||
}
|
||||
session, models, err := loginMgr.CompleteManualFlowWithOptions(r.Context(), pc.APIBase, flow, body.CallbackURL, providers.OAuthLoginOptions{
|
||||
AccountLabel: strings.TrimSpace(body.AccountLabel),
|
||||
NetworkProxy: firstNonEmptyString(strings.TrimSpace(body.NetworkProxy), strings.TrimSpace(pc.OAuth.NetworkProxy)),
|
||||
})
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
@@ -1111,6 +1119,7 @@ func (s *Server) handleWebUIProviderOAuthComplete(w http.ResponseWriter, r *http
|
||||
"ok": true,
|
||||
"account": session.Email,
|
||||
"credential_file": session.CredentialFile,
|
||||
"network_proxy": session.NetworkProxy,
|
||||
"models": models,
|
||||
})
|
||||
}
|
||||
@@ -1130,6 +1139,7 @@ func (s *Server) handleWebUIProviderOAuthImport(w http.ResponseWriter, r *http.R
|
||||
}
|
||||
providerName := strings.TrimSpace(r.FormValue("provider"))
|
||||
accountLabel := strings.TrimSpace(r.FormValue("account_label"))
|
||||
networkProxy := strings.TrimSpace(r.FormValue("network_proxy"))
|
||||
inlineCfgRaw := strings.TrimSpace(r.FormValue("provider_config"))
|
||||
var inlineCfg cfgpkg.ProviderConfig
|
||||
if inlineCfgRaw != "" {
|
||||
@@ -1165,6 +1175,7 @@ func (s *Server) handleWebUIProviderOAuthImport(w http.ResponseWriter, r *http.R
|
||||
}
|
||||
session, models, err := loginMgr.ImportAuthJSONWithOptions(r.Context(), pc.APIBase, header.Filename, data, providers.OAuthLoginOptions{
|
||||
AccountLabel: accountLabel,
|
||||
NetworkProxy: firstNonEmptyString(networkProxy, strings.TrimSpace(pc.OAuth.NetworkProxy)),
|
||||
})
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
@@ -1185,6 +1196,7 @@ func (s *Server) handleWebUIProviderOAuthImport(w http.ResponseWriter, r *http.R
|
||||
"ok": true,
|
||||
"account": session.Email,
|
||||
"credential_file": session.CredentialFile,
|
||||
"network_proxy": session.NetworkProxy,
|
||||
"models": models,
|
||||
})
|
||||
}
|
||||
@@ -1401,10 +1413,7 @@ func (s *Server) loadProviderConfig(name string) (*cfgpkg.Config, cfgpkg.Provide
|
||||
return nil, cfgpkg.ProviderConfig{}, err
|
||||
}
|
||||
providerName := strings.TrimSpace(name)
|
||||
if providerName == "" || providerName == "proxy" {
|
||||
return cfg, cfg.Providers.Proxy, nil
|
||||
}
|
||||
pc, ok := cfg.Providers.Proxies[providerName]
|
||||
pc, ok := cfg.Models.Providers[providerName]
|
||||
if !ok {
|
||||
return nil, cfgpkg.ProviderConfig{}, fmt.Errorf("provider %q not found", providerName)
|
||||
}
|
||||
@@ -1435,14 +1444,10 @@ func (s *Server) saveProviderConfig(cfg *cfgpkg.Config, name string, pc cfgpkg.P
|
||||
return fmt.Errorf("config is nil")
|
||||
}
|
||||
providerName := strings.TrimSpace(name)
|
||||
if providerName == "" || providerName == "proxy" {
|
||||
cfg.Providers.Proxy = pc
|
||||
} else {
|
||||
if cfg.Providers.Proxies == nil {
|
||||
cfg.Providers.Proxies = map[string]cfgpkg.ProviderConfig{}
|
||||
}
|
||||
cfg.Providers.Proxies[providerName] = pc
|
||||
if cfg.Models.Providers == nil {
|
||||
cfg.Models.Providers = map[string]cfgpkg.ProviderConfig{}
|
||||
}
|
||||
cfg.Models.Providers[providerName] = pc
|
||||
if err := cfgpkg.SaveConfig(s.configPath, cfg); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -6052,7 +6057,7 @@ func hotReloadFieldInfo() []map[string]interface{} {
|
||||
{"path": "logging.*", "name": "Logging", "description": "Log level, persistence, and related settings"},
|
||||
{"path": "sentinel.*", "name": "Sentinel", "description": "Health checks and auto-heal behavior"},
|
||||
{"path": "agents.*", "name": "Agent", "description": "Models, policies, and default behavior"},
|
||||
{"path": "providers.*", "name": "Providers", "description": "LLM providers and proxy settings"},
|
||||
{"path": "models.providers.*", "name": "Providers", "description": "LLM provider registry and auth settings"},
|
||||
{"path": "tools.*", "name": "Tools", "description": "Tool toggles and runtime options"},
|
||||
{"path": "channels.*", "name": "Channels", "description": "Telegram and other channel settings"},
|
||||
{"path": "cron.*", "name": "Cron", "description": "Global cron runtime settings"},
|
||||
|
||||
@@ -248,16 +248,20 @@ func TestHandleWebUIConfigRequiresConfirmForProviderAPIBaseChange(t *testing.T)
|
||||
|
||||
cfg := cfgpkg.DefaultConfig()
|
||||
cfg.Logging.Enabled = false
|
||||
cfg.Providers.Proxy.APIBase = "https://old.example/v1"
|
||||
cfg.Providers.Proxy.APIKey = "test-key"
|
||||
pc := cfg.Models.Providers["openai"]
|
||||
pc.APIBase = "https://old.example/v1"
|
||||
pc.APIKey = "test-key"
|
||||
cfg.Models.Providers["openai"] = pc
|
||||
if err := cfgpkg.SaveConfig(cfgPath, cfg); err != nil {
|
||||
t.Fatalf("save config: %v", err)
|
||||
}
|
||||
|
||||
bodyCfg := cfgpkg.DefaultConfig()
|
||||
bodyCfg.Logging.Enabled = false
|
||||
bodyCfg.Providers.Proxy.APIBase = "https://new.example/v1"
|
||||
bodyCfg.Providers.Proxy.APIKey = "test-key"
|
||||
bodyPC := bodyCfg.Models.Providers["openai"]
|
||||
bodyPC.APIBase = "https://new.example/v1"
|
||||
bodyPC.APIKey = "test-key"
|
||||
bodyCfg.Models.Providers["openai"] = bodyPC
|
||||
body, err := json.Marshal(bodyCfg)
|
||||
if err != nil {
|
||||
t.Fatalf("marshal body: %v", err)
|
||||
@@ -278,8 +282,8 @@ func TestHandleWebUIConfigRequiresConfirmForProviderAPIBaseChange(t *testing.T)
|
||||
if !strings.Contains(rec.Body.String(), `"requires_confirm":true`) {
|
||||
t.Fatalf("expected requires_confirm response, got: %s", rec.Body.String())
|
||||
}
|
||||
if !strings.Contains(rec.Body.String(), `providers.proxy.api_base`) {
|
||||
t.Fatalf("expected providers.proxy.api_base in changed_fields, got: %s", rec.Body.String())
|
||||
if !strings.Contains(rec.Body.String(), `models.providers.openai.api_base`) {
|
||||
t.Fatalf("expected models.providers.openai.api_base in changed_fields, got: %s", rec.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -291,7 +295,7 @@ func TestHandleWebUIConfigRequiresConfirmForCustomProviderSecretChange(t *testin
|
||||
|
||||
cfg := cfgpkg.DefaultConfig()
|
||||
cfg.Logging.Enabled = false
|
||||
cfg.Providers.Proxies["backup"] = cfgpkg.ProviderConfig{
|
||||
cfg.Models.Providers["backup"] = cfgpkg.ProviderConfig{
|
||||
APIBase: "https://backup.example/v1",
|
||||
APIKey: "old-secret",
|
||||
Models: []string{"backup-model"},
|
||||
@@ -304,7 +308,7 @@ func TestHandleWebUIConfigRequiresConfirmForCustomProviderSecretChange(t *testin
|
||||
|
||||
bodyCfg := cfgpkg.DefaultConfig()
|
||||
bodyCfg.Logging.Enabled = false
|
||||
bodyCfg.Providers.Proxies["backup"] = cfgpkg.ProviderConfig{
|
||||
bodyCfg.Models.Providers["backup"] = cfgpkg.ProviderConfig{
|
||||
APIBase: "https://backup.example/v1",
|
||||
APIKey: "new-secret",
|
||||
Models: []string{"backup-model"},
|
||||
@@ -331,8 +335,8 @@ func TestHandleWebUIConfigRequiresConfirmForCustomProviderSecretChange(t *testin
|
||||
if !strings.Contains(rec.Body.String(), `"requires_confirm":true`) {
|
||||
t.Fatalf("expected requires_confirm response, got: %s", rec.Body.String())
|
||||
}
|
||||
if !strings.Contains(rec.Body.String(), `providers.proxies.backup.api_key`) {
|
||||
t.Fatalf("expected providers.proxies.backup.api_key in changed_fields, got: %s", rec.Body.String())
|
||||
if !strings.Contains(rec.Body.String(), `models.providers.backup.api_key`) {
|
||||
t.Fatalf("expected models.providers.backup.api_key in changed_fields, got: %s", rec.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@@ -16,16 +17,16 @@ import (
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
Agents AgentsConfig `json:"agents"`
|
||||
Channels ChannelsConfig `json:"channels"`
|
||||
Providers ProvidersConfig `json:"providers"`
|
||||
Gateway GatewayConfig `json:"gateway"`
|
||||
Cron CronConfig `json:"cron"`
|
||||
Tools ToolsConfig `json:"tools"`
|
||||
Logging LoggingConfig `json:"logging"`
|
||||
Sentinel SentinelConfig `json:"sentinel"`
|
||||
Memory MemoryConfig `json:"memory"`
|
||||
mu sync.RWMutex
|
||||
Agents AgentsConfig `json:"agents"`
|
||||
Channels ChannelsConfig `json:"channels"`
|
||||
Models ModelsConfig `json:"models,omitempty"`
|
||||
Gateway GatewayConfig `json:"gateway"`
|
||||
Cron CronConfig `json:"cron"`
|
||||
Tools ToolsConfig `json:"tools"`
|
||||
Logging LoggingConfig `json:"logging"`
|
||||
Sentinel SentinelConfig `json:"sentinel"`
|
||||
Memory MemoryConfig `json:"memory"`
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
type AgentsConfig struct {
|
||||
@@ -107,7 +108,7 @@ type SubagentToolsConfig struct {
|
||||
}
|
||||
|
||||
type SubagentRuntimeConfig struct {
|
||||
Proxy string `json:"proxy,omitempty"`
|
||||
Provider string `json:"provider,omitempty"`
|
||||
Model string `json:"model,omitempty"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
TimeoutSec int `json:"timeout_sec,omitempty"`
|
||||
@@ -120,8 +121,7 @@ type SubagentRuntimeConfig struct {
|
||||
|
||||
type AgentDefaults struct {
|
||||
Workspace string `json:"workspace" env:"CLAWGO_AGENTS_DEFAULTS_WORKSPACE"`
|
||||
Proxy string `json:"proxy" env:"CLAWGO_AGENTS_DEFAULTS_PROXY"`
|
||||
ProxyFallbacks []string `json:"proxy_fallbacks" env:"CLAWGO_AGENTS_DEFAULTS_PROXY_FALLBACKS"`
|
||||
Model AgentModelDefaults `json:"model,omitempty"`
|
||||
MaxTokens int `json:"max_tokens" env:"CLAWGO_AGENTS_DEFAULTS_MAX_TOKENS"`
|
||||
Temperature float64 `json:"temperature" env:"CLAWGO_AGENTS_DEFAULTS_TEMPERATURE"`
|
||||
MaxToolIterations int `json:"max_tool_iterations" env:"CLAWGO_AGENTS_DEFAULTS_MAX_TOOL_ITERATIONS"`
|
||||
@@ -131,6 +131,11 @@ type AgentDefaults struct {
|
||||
SummaryPolicy SystemSummaryPolicyConfig `json:"summary_policy"`
|
||||
}
|
||||
|
||||
type AgentModelDefaults struct {
|
||||
Primary string `json:"primary,omitempty" env:"CLAWGO_AGENTS_DEFAULTS_MODEL_PRIMARY"`
|
||||
Fallbacks []string `json:"fallbacks,omitempty" env:"CLAWGO_AGENTS_DEFAULTS_MODEL_FALLBACKS"`
|
||||
}
|
||||
|
||||
type HeartbeatConfig struct {
|
||||
Enabled bool `json:"enabled" env:"CLAWGO_AGENTS_DEFAULTS_HEARTBEAT_ENABLED"`
|
||||
EverySec int `json:"every_sec" env:"CLAWGO_AGENTS_DEFAULTS_HEARTBEAT_EVERY_SEC"`
|
||||
@@ -234,52 +239,8 @@ type DingTalkConfig struct {
|
||||
AllowFrom []string `json:"allow_from" env:"CLAWGO_CHANNELS_DINGTALK_ALLOW_FROM"`
|
||||
}
|
||||
|
||||
type ProvidersConfig struct {
|
||||
Proxy ProviderConfig `json:"proxy"`
|
||||
Proxies map[string]ProviderConfig `json:"proxies"`
|
||||
}
|
||||
|
||||
type providerProxyItem struct {
|
||||
Name string `json:"name"`
|
||||
ProviderConfig
|
||||
}
|
||||
|
||||
func (p *ProvidersConfig) UnmarshalJSON(data []byte) error {
|
||||
var tmp struct {
|
||||
Proxy ProviderConfig `json:"proxy"`
|
||||
Proxies json.RawMessage `json:"proxies"`
|
||||
}
|
||||
if err := json.Unmarshal(data, &tmp); err != nil {
|
||||
return err
|
||||
}
|
||||
p.Proxy = tmp.Proxy
|
||||
p.Proxies = map[string]ProviderConfig{}
|
||||
if len(bytes.TrimSpace(tmp.Proxies)) == 0 || string(bytes.TrimSpace(tmp.Proxies)) == "null" {
|
||||
return nil
|
||||
}
|
||||
// Preferred format: object map
|
||||
var asMap map[string]ProviderConfig
|
||||
if err := json.Unmarshal(tmp.Proxies, &asMap); err == nil {
|
||||
for k, v := range asMap {
|
||||
if k == "" {
|
||||
continue
|
||||
}
|
||||
p.Proxies[k] = v
|
||||
}
|
||||
return nil
|
||||
}
|
||||
// Compatibility format: array [{name, ...provider fields...}]
|
||||
var asArr []providerProxyItem
|
||||
if err := json.Unmarshal(tmp.Proxies, &asArr); err == nil {
|
||||
for _, it := range asArr {
|
||||
if it.Name == "" {
|
||||
continue
|
||||
}
|
||||
p.Proxies[it.Name] = it.ProviderConfig
|
||||
}
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("providers.proxies must be object map or array of {name,...}")
|
||||
type ModelsConfig struct {
|
||||
Providers map[string]ProviderConfig `json:"providers,omitempty"`
|
||||
}
|
||||
|
||||
type ProviderConfig struct {
|
||||
@@ -298,6 +259,7 @@ type ProviderConfig struct {
|
||||
|
||||
type ProviderOAuthConfig struct {
|
||||
Provider string `json:"provider,omitempty"`
|
||||
NetworkProxy string `json:"network_proxy,omitempty"`
|
||||
CredentialFile string `json:"credential_file,omitempty"`
|
||||
CredentialFiles []string `json:"credential_files,omitempty"`
|
||||
CallbackPort int `json:"callback_port,omitempty"`
|
||||
@@ -307,7 +269,6 @@ type ProviderOAuthConfig struct {
|
||||
TokenURL string `json:"token_url,omitempty"`
|
||||
RedirectURL string `json:"redirect_url,omitempty"`
|
||||
Scopes []string `json:"scopes,omitempty"`
|
||||
HybridPriority string `json:"hybrid_priority,omitempty"`
|
||||
CooldownSec int `json:"cooldown_sec,omitempty"`
|
||||
RefreshScanSec int `json:"refresh_scan_sec,omitempty"`
|
||||
RefreshLeadSec int `json:"refresh_lead_sec,omitempty"`
|
||||
@@ -484,8 +445,7 @@ func DefaultConfig() *Config {
|
||||
Agents: AgentsConfig{
|
||||
Defaults: AgentDefaults{
|
||||
Workspace: filepath.Join(configDir, "workspace"),
|
||||
Proxy: "proxy",
|
||||
ProxyFallbacks: []string{},
|
||||
Model: AgentModelDefaults{Primary: "openai/gpt-5.4", Fallbacks: []string{}},
|
||||
MaxTokens: 8192,
|
||||
Temperature: 0.7,
|
||||
MaxToolIterations: 20,
|
||||
@@ -599,13 +559,14 @@ func DefaultConfig() *Config {
|
||||
AllowFrom: []string{},
|
||||
},
|
||||
},
|
||||
Providers: ProvidersConfig{
|
||||
Proxy: ProviderConfig{
|
||||
APIBase: "http://localhost:8080/v1",
|
||||
Models: []string{"glm-4.7"},
|
||||
TimeoutSec: 90,
|
||||
Models: ModelsConfig{
|
||||
Providers: map[string]ProviderConfig{
|
||||
"openai": {
|
||||
APIBase: "https://api.openai.com/v1",
|
||||
Models: []string{"gpt-5.4"},
|
||||
TimeoutSec: 90,
|
||||
},
|
||||
},
|
||||
Proxies: map[string]ProviderConfig{},
|
||||
},
|
||||
Gateway: GatewayConfig{
|
||||
Host: "0.0.0.0",
|
||||
@@ -699,6 +660,60 @@ func DefaultConfig() *Config {
|
||||
}
|
||||
}
|
||||
|
||||
func ParseProviderModelRef(raw string) (provider string, model string) {
|
||||
trimmed := strings.TrimSpace(raw)
|
||||
if trimmed == "" {
|
||||
return "", ""
|
||||
}
|
||||
if idx := strings.Index(trimmed, "/"); idx > 0 {
|
||||
return strings.TrimSpace(trimmed[:idx]), strings.TrimSpace(trimmed[idx+1:])
|
||||
}
|
||||
return "", trimmed
|
||||
}
|
||||
|
||||
func AllProviderConfigs(cfg *Config) map[string]ProviderConfig {
|
||||
out := map[string]ProviderConfig{}
|
||||
if cfg == nil {
|
||||
return out
|
||||
}
|
||||
for name, pc := range cfg.Models.Providers {
|
||||
trimmed := strings.TrimSpace(name)
|
||||
if trimmed == "" {
|
||||
continue
|
||||
}
|
||||
out[trimmed] = pc
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func ProviderConfigByName(cfg *Config, name string) (ProviderConfig, bool) {
|
||||
if cfg == nil {
|
||||
return ProviderConfig{}, false
|
||||
}
|
||||
pc, ok := AllProviderConfigs(cfg)[strings.TrimSpace(name)]
|
||||
return pc, ok
|
||||
}
|
||||
|
||||
func ProviderExists(cfg *Config, name string) bool {
|
||||
_, ok := ProviderConfigByName(cfg, name)
|
||||
return ok
|
||||
}
|
||||
|
||||
func PrimaryProviderName(cfg *Config) string {
|
||||
if cfg == nil {
|
||||
return "openai"
|
||||
}
|
||||
if provider, _ := ParseProviderModelRef(cfg.Agents.Defaults.Model.Primary); provider != "" {
|
||||
return provider
|
||||
}
|
||||
for name := range cfg.Models.Providers {
|
||||
if trimmed := strings.TrimSpace(name); trimmed != "" {
|
||||
return trimmed
|
||||
}
|
||||
}
|
||||
return "openai"
|
||||
}
|
||||
|
||||
func generateGatewayToken() string {
|
||||
var buf [16]byte
|
||||
if _, err := rand.Read(buf[:]); err != nil {
|
||||
@@ -771,13 +786,19 @@ func (c *Config) WorkspacePath() string {
|
||||
func (c *Config) GetAPIKey() string {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
return c.Providers.Proxy.APIKey
|
||||
if pc, ok := c.Models.Providers[PrimaryProviderName(c)]; ok {
|
||||
return pc.APIKey
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (c *Config) GetAPIBase() string {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
return c.Providers.Proxy.APIBase
|
||||
if pc, ok := c.Models.Providers[PrimaryProviderName(c)]; ok {
|
||||
return pc.APIBase
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (c *Config) LogFilePath() string {
|
||||
|
||||
@@ -87,30 +87,33 @@ func Validate(cfg *Config) []error {
|
||||
}
|
||||
}
|
||||
|
||||
if len(cfg.Providers.Proxies) == 0 {
|
||||
errs = append(errs, validateProviderConfig("providers.proxy", cfg.Providers.Proxy)...)
|
||||
} else {
|
||||
for name, p := range cfg.Providers.Proxies {
|
||||
errs = append(errs, validateProviderConfig("providers.proxies."+name, p)...)
|
||||
for name, p := range cfg.Models.Providers {
|
||||
errs = append(errs, validateProviderConfig("models.providers."+name, p)...)
|
||||
}
|
||||
if len(cfg.Models.Providers) == 0 {
|
||||
errs = append(errs, fmt.Errorf("models.providers must contain at least one provider"))
|
||||
}
|
||||
for _, name := range cfg.Agents.Defaults.Model.Fallbacks {
|
||||
if !ProviderExists(cfg, name) {
|
||||
errs = append(errs, fmt.Errorf("agents.defaults.model.fallbacks contains unknown provider %q", name))
|
||||
}
|
||||
}
|
||||
if cfg.Agents.Defaults.Proxy != "" {
|
||||
if !providerExists(cfg, cfg.Agents.Defaults.Proxy) {
|
||||
errs = append(errs, fmt.Errorf("agents.defaults.proxy %q not found in providers", cfg.Agents.Defaults.Proxy))
|
||||
if primaryRef := strings.TrimSpace(cfg.Agents.Defaults.Model.Primary); primaryRef != "" {
|
||||
providerName, modelName := ParseProviderModelRef(primaryRef)
|
||||
if providerName == "" {
|
||||
providerName = PrimaryProviderName(cfg)
|
||||
}
|
||||
}
|
||||
for _, name := range cfg.Agents.Defaults.ProxyFallbacks {
|
||||
if !providerExists(cfg, name) {
|
||||
errs = append(errs, fmt.Errorf("agents.defaults.proxy_fallbacks contains unknown proxy %q", name))
|
||||
if !ProviderExists(cfg, providerName) {
|
||||
errs = append(errs, fmt.Errorf("agents.defaults.model.primary %q references unknown provider %q", primaryRef, providerName))
|
||||
}
|
||||
if strings.TrimSpace(modelName) == "" {
|
||||
errs = append(errs, fmt.Errorf("agents.defaults.model.primary must include a model, expected provider/model"))
|
||||
}
|
||||
}
|
||||
if cfg.Agents.Defaults.ContextCompaction.Enabled && cfg.Agents.Defaults.ContextCompaction.Mode == "responses_compact" {
|
||||
active := cfg.Agents.Defaults.Proxy
|
||||
if active == "" {
|
||||
active = "proxy"
|
||||
}
|
||||
if pc, ok := providerConfigByName(cfg, active); !ok || !pc.SupportsResponsesCompact {
|
||||
errs = append(errs, fmt.Errorf("context_compaction.mode=responses_compact requires active proxy %q with supports_responses_compact=true", active))
|
||||
active := PrimaryProviderName(cfg)
|
||||
if pc, ok := ProviderConfigByName(cfg, active); !ok || !pc.SupportsResponsesCompact {
|
||||
errs = append(errs, fmt.Errorf("context_compaction.mode=responses_compact requires active provider %q with supports_responses_compact=true", active))
|
||||
}
|
||||
}
|
||||
errs = append(errs, validateAgentRouter(cfg)...)
|
||||
@@ -474,8 +477,8 @@ func validateSubagents(cfg *Config) []error {
|
||||
errs = append(errs, fmt.Errorf("agents.subagents.%s.system_prompt_file must stay within workspace", id))
|
||||
}
|
||||
}
|
||||
if proxy := strings.TrimSpace(raw.Runtime.Proxy); proxy != "" && !providerExists(cfg, proxy) {
|
||||
errs = append(errs, fmt.Errorf("agents.subagents.%s.runtime.proxy %q not found in providers", id, proxy))
|
||||
if provider := strings.TrimSpace(raw.Runtime.Provider); provider != "" && !ProviderExists(cfg, provider) {
|
||||
errs = append(errs, fmt.Errorf("agents.subagents.%s.runtime.provider %q not found in providers", id, provider))
|
||||
}
|
||||
for _, sender := range raw.AcceptFrom {
|
||||
sender = strings.TrimSpace(sender)
|
||||
@@ -562,13 +565,6 @@ func validateProviderConfig(path string, p ProviderConfig) []error {
|
||||
errs = append(errs, fmt.Errorf("%s.oauth.provider is required when auth=hybrid", path))
|
||||
}
|
||||
}
|
||||
if p.OAuth.HybridPriority != "" {
|
||||
switch strings.ToLower(strings.TrimSpace(p.OAuth.HybridPriority)) {
|
||||
case "api_first", "oauth_first":
|
||||
default:
|
||||
errs = append(errs, fmt.Errorf("%s.oauth.hybrid_priority must be one of: api_first, oauth_first", path))
|
||||
}
|
||||
}
|
||||
if p.OAuth.CooldownSec < 0 {
|
||||
errs = append(errs, fmt.Errorf("%s.oauth.cooldown_sec must be >= 0", path))
|
||||
}
|
||||
@@ -587,25 +583,6 @@ func validateProviderConfig(path string, p ProviderConfig) []error {
|
||||
return errs
|
||||
}
|
||||
|
||||
func providerExists(cfg *Config, name string) bool {
|
||||
if name == "proxy" && cfg.Providers.Proxy.APIBase != "" {
|
||||
return true
|
||||
}
|
||||
if cfg.Providers.Proxies == nil {
|
||||
return false
|
||||
}
|
||||
_, ok := cfg.Providers.Proxies[name]
|
||||
return ok
|
||||
}
|
||||
|
||||
func providerConfigByName(cfg *Config, name string) (ProviderConfig, bool) {
|
||||
if strings.TrimSpace(name) == "proxy" {
|
||||
return cfg.Providers.Proxy, true
|
||||
}
|
||||
pc, ok := cfg.Providers.Proxies[name]
|
||||
return pc, ok
|
||||
}
|
||||
|
||||
func validateNonEmptyStringList(path string, values []string) []error {
|
||||
if len(values) == 0 {
|
||||
return nil
|
||||
|
||||
@@ -34,7 +34,7 @@ func TestValidateSubagentsAllowsKnownPeers(t *testing.T) {
|
||||
AcceptFrom: []string{"main"},
|
||||
CanTalkTo: []string{"main"},
|
||||
Runtime: SubagentRuntimeConfig{
|
||||
Proxy: "proxy",
|
||||
Provider: "openai",
|
||||
},
|
||||
}
|
||||
|
||||
@@ -66,7 +66,7 @@ func TestValidateSubagentsRejectsAbsolutePromptFile(t *testing.T) {
|
||||
Enabled: true,
|
||||
SystemPromptFile: "/tmp/AGENT.md",
|
||||
Runtime: SubagentRuntimeConfig{
|
||||
Proxy: "proxy",
|
||||
Provider: "openai",
|
||||
},
|
||||
}
|
||||
|
||||
@@ -82,7 +82,7 @@ func TestValidateSubagentsRequiresPromptFileWhenEnabled(t *testing.T) {
|
||||
cfg.Agents.Subagents["coder"] = SubagentConfig{
|
||||
Enabled: true,
|
||||
Runtime: SubagentRuntimeConfig{
|
||||
Proxy: "proxy",
|
||||
Provider: "openai",
|
||||
},
|
||||
}
|
||||
|
||||
@@ -123,7 +123,7 @@ func TestValidateSubagentsRejectsInvalidNotifyMainPolicy(t *testing.T) {
|
||||
SystemPromptFile: "agents/coder/AGENT.md",
|
||||
NotifyMainPolicy: "loud",
|
||||
Runtime: SubagentRuntimeConfig{
|
||||
Proxy: "proxy",
|
||||
Provider: "openai",
|
||||
},
|
||||
}
|
||||
|
||||
@@ -276,9 +276,11 @@ func TestValidateProviderOAuthAllowsEmptyModelsBeforeLogin(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cfg := DefaultConfig()
|
||||
cfg.Providers.Proxy.Auth = "oauth"
|
||||
cfg.Providers.Proxy.Models = nil
|
||||
cfg.Providers.Proxy.OAuth = ProviderOAuthConfig{Provider: "codex"}
|
||||
pc := cfg.Models.Providers["openai"]
|
||||
pc.Auth = "oauth"
|
||||
pc.Models = nil
|
||||
pc.OAuth = ProviderOAuthConfig{Provider: "codex"}
|
||||
cfg.Models.Providers["openai"] = pc
|
||||
|
||||
if errs := Validate(cfg); len(errs) != 0 {
|
||||
t.Fatalf("expected oauth provider config to be valid before model sync, got %v", errs)
|
||||
@@ -289,9 +291,11 @@ func TestValidateProviderOAuthRequiresProviderName(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cfg := DefaultConfig()
|
||||
cfg.Providers.Proxy.Auth = "oauth"
|
||||
cfg.Providers.Proxy.Models = nil
|
||||
cfg.Providers.Proxy.OAuth = ProviderOAuthConfig{}
|
||||
pc := cfg.Models.Providers["openai"]
|
||||
pc.Auth = "oauth"
|
||||
pc.Models = nil
|
||||
pc.OAuth = ProviderOAuthConfig{}
|
||||
cfg.Models.Providers["openai"] = pc
|
||||
|
||||
errs := Validate(cfg)
|
||||
if len(errs) == 0 {
|
||||
@@ -299,7 +303,7 @@ func TestValidateProviderOAuthRequiresProviderName(t *testing.T) {
|
||||
}
|
||||
found := false
|
||||
for _, err := range errs {
|
||||
if strings.Contains(err.Error(), "providers.proxy.oauth.provider") {
|
||||
if strings.Contains(err.Error(), "models.providers.openai.oauth.provider") {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
@@ -313,10 +317,12 @@ func TestValidateProviderHybridAllowsEmptyModels(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cfg := DefaultConfig()
|
||||
cfg.Providers.Proxy.Auth = "hybrid"
|
||||
cfg.Providers.Proxy.APIKey = "sk-test"
|
||||
cfg.Providers.Proxy.Models = nil
|
||||
cfg.Providers.Proxy.OAuth = ProviderOAuthConfig{Provider: "codex"}
|
||||
pc := cfg.Models.Providers["openai"]
|
||||
pc.Auth = "hybrid"
|
||||
pc.APIKey = "sk-test"
|
||||
pc.Models = nil
|
||||
pc.OAuth = ProviderOAuthConfig{Provider: "codex"}
|
||||
cfg.Models.Providers["openai"] = pc
|
||||
|
||||
if errs := Validate(cfg); len(errs) != 0 {
|
||||
t.Fatalf("expected hybrid provider config to be valid before model sync, got %v", errs)
|
||||
@@ -327,10 +333,12 @@ func TestValidateProviderHybridRequiresOAuthProvider(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cfg := DefaultConfig()
|
||||
cfg.Providers.Proxy.Auth = "hybrid"
|
||||
cfg.Providers.Proxy.APIKey = "sk-test"
|
||||
cfg.Providers.Proxy.Models = nil
|
||||
cfg.Providers.Proxy.OAuth = ProviderOAuthConfig{}
|
||||
pc := cfg.Models.Providers["openai"]
|
||||
pc.Auth = "hybrid"
|
||||
pc.APIKey = "sk-test"
|
||||
pc.Models = nil
|
||||
pc.OAuth = ProviderOAuthConfig{}
|
||||
cfg.Models.Providers["openai"] = pc
|
||||
|
||||
errs := Validate(cfg)
|
||||
if len(errs) == 0 {
|
||||
@@ -338,7 +346,7 @@ func TestValidateProviderHybridRequiresOAuthProvider(t *testing.T) {
|
||||
}
|
||||
found := false
|
||||
for _, err := range errs {
|
||||
if strings.Contains(err.Error(), "providers.proxy.oauth.provider") {
|
||||
if strings.Contains(err.Error(), "models.providers.openai.oauth.provider") {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
@@ -347,30 +355,3 @@ func TestValidateProviderHybridRequiresOAuthProvider(t *testing.T) {
|
||||
t.Fatalf("expected oauth.provider validation error, got %v", errs)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateProviderHybridPriorityRejectsInvalidValue(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cfg := DefaultConfig()
|
||||
cfg.Providers.Proxy.Auth = "hybrid"
|
||||
cfg.Providers.Proxy.APIKey = "sk-test"
|
||||
cfg.Providers.Proxy.OAuth = ProviderOAuthConfig{
|
||||
Provider: "codex",
|
||||
HybridPriority: "random_first",
|
||||
}
|
||||
|
||||
errs := Validate(cfg)
|
||||
if len(errs) == 0 {
|
||||
t.Fatalf("expected validation errors")
|
||||
}
|
||||
found := false
|
||||
for _, err := range errs {
|
||||
if strings.Contains(err.Error(), "oauth.hybrid_priority") {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Fatalf("expected oauth.hybrid_priority validation error, got %v", errs)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
@@ -16,16 +17,25 @@ type anthropicOAuthRoundTripper struct {
|
||||
connections map[string]*http2.ClientConn
|
||||
pending map[string]*sync.Cond
|
||||
dialer net.Dialer
|
||||
dialContext func(context.Context, string, string) (net.Conn, error)
|
||||
}
|
||||
|
||||
func newAnthropicOAuthHTTPClient(timeout time.Duration) *http.Client {
|
||||
func newAnthropicOAuthHTTPClient(timeout time.Duration, proxyURL string) (*http.Client, error) {
|
||||
rt, err := newAnthropicOAuthRoundTripper(proxyURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &http.Client{
|
||||
Timeout: timeout,
|
||||
Transport: newAnthropicOAuthRoundTripper(),
|
||||
}
|
||||
Transport: rt,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func newAnthropicOAuthRoundTripper() *anthropicOAuthRoundTripper {
|
||||
func newAnthropicOAuthRoundTripper(proxyURL string) (*anthropicOAuthRoundTripper, error) {
|
||||
dialContext, err := proxyDialContext(proxyURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &anthropicOAuthRoundTripper{
|
||||
connections: map[string]*http2.ClientConn{},
|
||||
pending: map[string]*sync.Cond{},
|
||||
@@ -33,7 +43,8 @@ func newAnthropicOAuthRoundTripper() *anthropicOAuthRoundTripper {
|
||||
Timeout: 15 * time.Second,
|
||||
KeepAlive: 30 * time.Second,
|
||||
},
|
||||
}
|
||||
dialContext: dialContext,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (t *anthropicOAuthRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
@@ -89,7 +100,11 @@ func (t *anthropicOAuthRoundTripper) getOrCreateConnection(host, addr string) (*
|
||||
}
|
||||
|
||||
func (t *anthropicOAuthRoundTripper) createConnection(host, addr string) (*http2.ClientConn, error) {
|
||||
rawConn, err := t.dialer.Dial("tcp", addr)
|
||||
dialContext := t.dialContext
|
||||
if dialContext == nil {
|
||||
dialContext = t.dialer.DialContext
|
||||
}
|
||||
rawConn, err := dialContext(context.Background(), "tcp", addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -665,7 +665,7 @@ func (p *HTTPProvider) postJSONStream(ctx context.Context, endpoint string, payl
|
||||
req.Header.Set("Accept", "text/event-stream")
|
||||
applyAttemptAuth(req, attempt)
|
||||
|
||||
body, status, ctype, quotaHit, err := p.doStreamAttempt(req, onEvent)
|
||||
body, status, ctype, quotaHit, err := p.doStreamAttempt(req, attempt, onEvent)
|
||||
if err != nil {
|
||||
return nil, 0, "", err
|
||||
}
|
||||
@@ -707,7 +707,7 @@ func (p *HTTPProvider) postJSON(ctx context.Context, endpoint string, payload in
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
applyAttemptAuth(req, attempt)
|
||||
|
||||
body, status, ctype, err := p.doJSONAttempt(req)
|
||||
body, status, ctype, err := p.doJSONAttempt(req, attempt)
|
||||
if err != nil {
|
||||
return nil, 0, "", err
|
||||
}
|
||||
@@ -753,7 +753,7 @@ func (p *HTTPProvider) authAttempts(ctx context.Context) ([]authAttempt, error)
|
||||
for _, attempt := range attempts {
|
||||
oauthAttempts = append(oauthAttempts, authAttempt{session: attempt.Session, token: attempt.Token, kind: "oauth"})
|
||||
}
|
||||
if mode == "hybrid" && apiReady && p.oauth.cfg.HybridPriority != "oauth_first" {
|
||||
if mode == "hybrid" && apiReady {
|
||||
out = append(out, apiAttempt)
|
||||
}
|
||||
if len(attempts) == 0 {
|
||||
@@ -764,9 +764,6 @@ func (p *HTTPProvider) authAttempts(ctx context.Context) ([]authAttempt, error)
|
||||
return nil, fmt.Errorf("oauth session not found, run `clawgo provider login` first")
|
||||
}
|
||||
out = append(out, oauthAttempts...)
|
||||
if mode == "hybrid" && apiReady && p.oauth.cfg.HybridPriority == "oauth_first" {
|
||||
out = append(out, apiAttempt)
|
||||
}
|
||||
p.updateCandidateOrder(out)
|
||||
return out, nil
|
||||
}
|
||||
@@ -833,8 +830,19 @@ func applyAttemptAuth(req *http.Request, attempt authAttempt) {
|
||||
req.Header.Set("Authorization", "Bearer "+attempt.token)
|
||||
}
|
||||
|
||||
func (p *HTTPProvider) doJSONAttempt(req *http.Request) ([]byte, int, string, error) {
|
||||
resp, err := p.httpClient.Do(req)
|
||||
func (p *HTTPProvider) httpClientForAttempt(attempt authAttempt) (*http.Client, error) {
|
||||
if attempt.kind == "oauth" && attempt.session != nil && p.oauth != nil {
|
||||
return p.oauth.httpClientForSession(attempt.session)
|
||||
}
|
||||
return p.httpClient, nil
|
||||
}
|
||||
|
||||
func (p *HTTPProvider) doJSONAttempt(req *http.Request, attempt authAttempt) ([]byte, int, string, error) {
|
||||
client, err := p.httpClientForAttempt(attempt)
|
||||
if err != nil {
|
||||
return nil, 0, "", err
|
||||
}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, 0, "", fmt.Errorf("failed to send request: %w", err)
|
||||
}
|
||||
@@ -846,8 +854,12 @@ func (p *HTTPProvider) doJSONAttempt(req *http.Request) ([]byte, int, string, er
|
||||
return body, resp.StatusCode, strings.TrimSpace(resp.Header.Get("Content-Type")), nil
|
||||
}
|
||||
|
||||
func (p *HTTPProvider) doStreamAttempt(req *http.Request, onEvent func(string)) ([]byte, int, string, bool, error) {
|
||||
resp, err := p.httpClient.Do(req)
|
||||
func (p *HTTPProvider) doStreamAttempt(req *http.Request, attempt authAttempt, onEvent func(string)) ([]byte, int, string, bool, error) {
|
||||
client, err := p.httpClientForAttempt(attempt)
|
||||
if err != nil {
|
||||
return nil, 0, "", false, err
|
||||
}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, 0, "", false, fmt.Errorf("failed to send request: %w", err)
|
||||
}
|
||||
@@ -1437,17 +1449,10 @@ func buildProviderCandidateOrder(_ string, pc config.ProviderConfig, accounts []
|
||||
case "oauth":
|
||||
out = append(out, oauthAvailable...)
|
||||
case "hybrid":
|
||||
if strings.EqualFold(strings.TrimSpace(pc.OAuth.HybridPriority), "oauth_first") {
|
||||
out = append(out, oauthAvailable...)
|
||||
if apiCandidate.Target != "" && apiCandidate.Available {
|
||||
out = append(out, apiCandidate)
|
||||
}
|
||||
} else {
|
||||
if apiCandidate.Target != "" && apiCandidate.Available {
|
||||
out = append(out, apiCandidate)
|
||||
}
|
||||
out = append(out, oauthAvailable...)
|
||||
if apiCandidate.Target != "" && apiCandidate.Available {
|
||||
out = append(out, apiCandidate)
|
||||
}
|
||||
out = append(out, oauthAvailable...)
|
||||
case "none":
|
||||
default:
|
||||
if apiCandidate.Target != "" {
|
||||
@@ -2131,11 +2136,16 @@ func (p *HTTPProvider) BuildSummaryViaResponsesCompact(ctx context.Context, mode
|
||||
}
|
||||
|
||||
func CreateProvider(cfg *config.Config) (LLMProvider, error) {
|
||||
name := strings.TrimSpace(cfg.Agents.Defaults.Proxy)
|
||||
if name == "" {
|
||||
name = "proxy"
|
||||
name := config.PrimaryProviderName(cfg)
|
||||
provider, err := CreateProviderByName(cfg, name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return CreateProviderByName(cfg, name)
|
||||
_, model := config.ParseProviderModelRef(cfg.Agents.Defaults.Model.Primary)
|
||||
if hp, ok := provider.(*HTTPProvider); ok && strings.TrimSpace(model) != "" {
|
||||
hp.defaultModel = strings.TrimSpace(model)
|
||||
}
|
||||
return provider, nil
|
||||
}
|
||||
|
||||
func CreateProviderByName(cfg *config.Config, name string) (LLMProvider, error) {
|
||||
@@ -2219,48 +2229,12 @@ func ListProviderNames(cfg *config.Config) []string {
|
||||
}
|
||||
|
||||
func getAllProviderConfigs(cfg *config.Config) map[string]config.ProviderConfig {
|
||||
out := map[string]config.ProviderConfig{}
|
||||
if cfg == nil {
|
||||
return out
|
||||
}
|
||||
includeLegacyProxy := len(cfg.Providers.Proxies) == 0 || strings.TrimSpace(cfg.Agents.Defaults.Proxy) == "proxy" || containsStringTrimmed(cfg.Agents.Defaults.ProxyFallbacks, "proxy")
|
||||
if includeLegacyProxy && (cfg.Providers.Proxy.APIBase != "" || cfg.Providers.Proxy.APIKey != "" || cfg.Providers.Proxy.TimeoutSec > 0) {
|
||||
out["proxy"] = cfg.Providers.Proxy
|
||||
}
|
||||
for name, pc := range cfg.Providers.Proxies {
|
||||
trimmed := strings.TrimSpace(name)
|
||||
if trimmed == "" {
|
||||
continue
|
||||
}
|
||||
out[trimmed] = pc
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func containsStringTrimmed(values []string, target string) bool {
|
||||
t := strings.TrimSpace(target)
|
||||
for _, v := range values {
|
||||
if strings.TrimSpace(v) == t {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
return config.AllProviderConfigs(cfg)
|
||||
}
|
||||
|
||||
func getProviderConfigByName(cfg *config.Config, name string) (config.ProviderConfig, error) {
|
||||
if cfg == nil {
|
||||
return config.ProviderConfig{}, fmt.Errorf("nil config")
|
||||
if pc, ok := config.ProviderConfigByName(cfg, name); ok {
|
||||
return pc, nil
|
||||
}
|
||||
trimmed := strings.TrimSpace(name)
|
||||
if trimmed == "" {
|
||||
return config.ProviderConfig{}, fmt.Errorf("empty provider name")
|
||||
}
|
||||
if trimmed == "proxy" {
|
||||
return cfg.Providers.Proxy, nil
|
||||
}
|
||||
pc, ok := cfg.Providers.Proxies[trimmed]
|
||||
if !ok {
|
||||
return config.ProviderConfig{}, fmt.Errorf("provider %q not found", trimmed)
|
||||
}
|
||||
return pc, nil
|
||||
return config.ProviderConfig{}, fmt.Errorf("provider %q not found", strings.TrimSpace(name))
|
||||
}
|
||||
|
||||
156
pkg/providers/http_proxy.go
Normal file
156
pkg/providers/http_proxy.go
Normal file
@@ -0,0 +1,156 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
stdtls "crypto/tls"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
xproxy "golang.org/x/net/proxy"
|
||||
)
|
||||
|
||||
func normalizeOptionalProxyURL(raw string) (string, error) {
|
||||
value := strings.TrimSpace(raw)
|
||||
if value == "" {
|
||||
return "", nil
|
||||
}
|
||||
if !strings.Contains(value, "://") {
|
||||
value = "http://" + value
|
||||
}
|
||||
parsed, err := url.Parse(value)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("invalid network proxy: %w", err)
|
||||
}
|
||||
if parsed.Scheme == "" || parsed.Host == "" {
|
||||
return "", fmt.Errorf("invalid network proxy: host is required")
|
||||
}
|
||||
switch strings.ToLower(strings.TrimSpace(parsed.Scheme)) {
|
||||
case "http", "https", "socks5", "socks5h":
|
||||
return parsed.String(), nil
|
||||
default:
|
||||
return "", fmt.Errorf("invalid network proxy: unsupported scheme %q", parsed.Scheme)
|
||||
}
|
||||
}
|
||||
|
||||
func maskedProxyURL(raw string) string {
|
||||
normalized, err := normalizeOptionalProxyURL(raw)
|
||||
if err != nil || normalized == "" {
|
||||
return ""
|
||||
}
|
||||
parsed, err := url.Parse(normalized)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
if parsed.User != nil {
|
||||
username := parsed.User.Username()
|
||||
if username != "" {
|
||||
parsed.User = url.UserPassword(username, "***")
|
||||
} else {
|
||||
parsed.User = url.User("***")
|
||||
}
|
||||
}
|
||||
return parsed.String()
|
||||
}
|
||||
|
||||
func proxyDialContext(proxyRaw string) (func(context.Context, string, string) (net.Conn, error), error) {
|
||||
normalized, err := normalizeOptionalProxyURL(proxyRaw)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if normalized == "" {
|
||||
dialer := &net.Dialer{Timeout: 15 * time.Second, KeepAlive: 30 * time.Second}
|
||||
return dialer.DialContext, nil
|
||||
}
|
||||
parsed, err := url.Parse(normalized)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
switch strings.ToLower(strings.TrimSpace(parsed.Scheme)) {
|
||||
case "socks5", "socks5h":
|
||||
base := &net.Dialer{Timeout: 15 * time.Second, KeepAlive: 30 * time.Second}
|
||||
dialer, err := xproxy.FromURL(parsed, base)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("configure socks proxy failed: %w", err)
|
||||
}
|
||||
if ctxDialer, ok := dialer.(xproxy.ContextDialer); ok {
|
||||
return ctxDialer.DialContext, nil
|
||||
}
|
||||
return func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
type dialResult struct {
|
||||
conn net.Conn
|
||||
err error
|
||||
}
|
||||
ch := make(chan dialResult, 1)
|
||||
go func() {
|
||||
conn, err := dialer.Dial(network, addr)
|
||||
ch <- dialResult{conn: conn, err: err}
|
||||
}()
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case res := <-ch:
|
||||
return res.conn, res.err
|
||||
}
|
||||
}, nil
|
||||
case "http", "https":
|
||||
base := &net.Dialer{Timeout: 15 * time.Second, KeepAlive: 30 * time.Second}
|
||||
return func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
conn, err := base.DialContext(ctx, "tcp", parsed.Host)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if strings.EqualFold(parsed.Scheme, "https") {
|
||||
tlsConn := stdtls.Client(conn, &stdtls.Config{ServerName: parsed.Hostname()})
|
||||
if err := tlsConn.HandshakeContext(ctx); err != nil {
|
||||
_ = conn.Close()
|
||||
return nil, err
|
||||
}
|
||||
conn = tlsConn
|
||||
}
|
||||
connectReq := buildProxyConnectRequest(parsed, addr)
|
||||
if _, err := conn.Write([]byte(connectReq)); err != nil {
|
||||
_ = conn.Close()
|
||||
return nil, err
|
||||
}
|
||||
br := bufio.NewReader(conn)
|
||||
resp, err := http.ReadResponse(br, &http.Request{Method: http.MethodConnect})
|
||||
if err != nil {
|
||||
_ = conn.Close()
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
_ = conn.Close()
|
||||
return nil, fmt.Errorf("proxy connect failed: status=%d", resp.StatusCode)
|
||||
}
|
||||
return conn, nil
|
||||
}, nil
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid network proxy: unsupported scheme %q", parsed.Scheme)
|
||||
}
|
||||
}
|
||||
|
||||
func buildProxyConnectRequest(proxyURL *url.URL, targetAddr string) string {
|
||||
var b strings.Builder
|
||||
b.WriteString("CONNECT ")
|
||||
b.WriteString(targetAddr)
|
||||
b.WriteString(" HTTP/1.1\r\nHost: ")
|
||||
b.WriteString(targetAddr)
|
||||
b.WriteString("\r\n")
|
||||
if proxyURL != nil && proxyURL.User != nil {
|
||||
username := proxyURL.User.Username()
|
||||
password, _ := proxyURL.User.Password()
|
||||
encoded := base64.StdEncoding.EncodeToString([]byte(username + ":" + password))
|
||||
b.WriteString("Proxy-Authorization: Basic ")
|
||||
b.WriteString(encoded)
|
||||
b.WriteString("\r\n")
|
||||
}
|
||||
b.WriteString("\r\n")
|
||||
return b.String()
|
||||
}
|
||||
@@ -118,6 +118,7 @@ type oauthSession struct {
|
||||
ProjectID string `json:"project_id,omitempty"`
|
||||
DeviceID string `json:"device_id,omitempty"`
|
||||
ResourceURL string `json:"resource_url,omitempty"`
|
||||
NetworkProxy string `json:"network_proxy,omitempty"`
|
||||
Scope string `json:"scope,omitempty"`
|
||||
Token map[string]any `json:"token,omitempty"`
|
||||
CooldownUntil string `json:"-"`
|
||||
@@ -143,7 +144,6 @@ type oauthConfig struct {
|
||||
Scopes []string
|
||||
RefreshScan time.Duration
|
||||
RefreshLead time.Duration
|
||||
HybridPriority string
|
||||
Cooldown time.Duration
|
||||
FlowKind string
|
||||
TokenStyle string
|
||||
@@ -154,7 +154,10 @@ type oauthConfig struct {
|
||||
type oauthManager struct {
|
||||
providerName string
|
||||
cfg oauthConfig
|
||||
timeout time.Duration
|
||||
httpClient *http.Client
|
||||
clientMu sync.Mutex
|
||||
clients map[string]*http.Client
|
||||
mu sync.Mutex
|
||||
cached []*oauthSession
|
||||
cooldowns map[string]time.Time
|
||||
@@ -198,6 +201,7 @@ type OAuthLoginOptions struct {
|
||||
NoBrowser bool
|
||||
Reader io.Reader
|
||||
AccountLabel string
|
||||
NetworkProxy string
|
||||
}
|
||||
|
||||
type OAuthPendingFlow struct {
|
||||
@@ -218,6 +222,7 @@ type OAuthSessionInfo struct {
|
||||
CredentialFile string
|
||||
ProjectID string
|
||||
AccountLabel string
|
||||
NetworkProxy string
|
||||
}
|
||||
|
||||
type OAuthAccountInfo struct {
|
||||
@@ -230,6 +235,7 @@ type OAuthAccountInfo struct {
|
||||
AccountLabel string `json:"account_label,omitempty"`
|
||||
DeviceID string `json:"device_id,omitempty"`
|
||||
ResourceURL string `json:"resource_url,omitempty"`
|
||||
NetworkProxy string `json:"network_proxy,omitempty"`
|
||||
CooldownUntil string `json:"cooldown_until,omitempty"`
|
||||
FailureCount int `json:"failure_count,omitempty"`
|
||||
LastFailure string `json:"last_failure,omitempty"`
|
||||
@@ -277,6 +283,7 @@ func (m *OAuthLoginManager) Login(ctx context.Context, apiBase string, opts OAut
|
||||
CredentialFile: session.FilePath,
|
||||
ProjectID: session.ProjectID,
|
||||
AccountLabel: sessionLabel(session),
|
||||
NetworkProxy: maskedProxyURL(session.NetworkProxy),
|
||||
}, models, nil
|
||||
}
|
||||
|
||||
@@ -288,13 +295,20 @@ func (m *OAuthLoginManager) CredentialFile() string {
|
||||
}
|
||||
|
||||
func (m *OAuthLoginManager) StartManualFlow() (*OAuthPendingFlow, error) {
|
||||
return m.StartManualFlowWithOptions(OAuthLoginOptions{})
|
||||
}
|
||||
|
||||
func (m *OAuthLoginManager) StartManualFlowWithOptions(opts OAuthLoginOptions) (*OAuthPendingFlow, error) {
|
||||
if m == nil || m.manager == nil {
|
||||
return nil, fmt.Errorf("oauth login manager not configured")
|
||||
}
|
||||
if m.manager.cfg.FlowKind == oauthFlowDevice {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
return m.manager.startDeviceFlow(ctx)
|
||||
return m.manager.startDeviceFlow(ctx, opts)
|
||||
}
|
||||
if _, err := normalizeOptionalProxyURL(opts.NetworkProxy); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
pkceVerifier, pkceChallenge, err := generatePKCE()
|
||||
if err != nil {
|
||||
@@ -350,6 +364,7 @@ func (m *OAuthLoginManager) CompleteManualFlowWithOptions(ctx context.Context, a
|
||||
CredentialFile: session.FilePath,
|
||||
ProjectID: session.ProjectID,
|
||||
AccountLabel: sessionLabel(session),
|
||||
NetworkProxy: maskedProxyURL(session.NetworkProxy),
|
||||
}, models, nil
|
||||
}
|
||||
|
||||
@@ -371,6 +386,7 @@ func (m *OAuthLoginManager) ImportAuthJSONWithOptions(ctx context.Context, apiBa
|
||||
CredentialFile: session.FilePath,
|
||||
ProjectID: session.ProjectID,
|
||||
AccountLabel: sessionLabel(session),
|
||||
NetworkProxy: maskedProxyURL(session.NetworkProxy),
|
||||
}, models, nil
|
||||
}
|
||||
|
||||
@@ -399,6 +415,7 @@ func (m *OAuthLoginManager) ListAccounts() ([]OAuthAccountInfo, error) {
|
||||
AccountLabel: sessionLabel(session),
|
||||
DeviceID: session.DeviceID,
|
||||
ResourceURL: session.ResourceURL,
|
||||
NetworkProxy: maskedProxyURL(session.NetworkProxy),
|
||||
CooldownUntil: session.CooldownUntil,
|
||||
FailureCount: session.FailureCount,
|
||||
LastFailure: session.LastFailure,
|
||||
@@ -436,6 +453,7 @@ func (m *OAuthLoginManager) RefreshAccount(ctx context.Context, credentialFile s
|
||||
AccountLabel: sessionLabel(refreshed),
|
||||
DeviceID: refreshed.DeviceID,
|
||||
ResourceURL: refreshed.ResourceURL,
|
||||
NetworkProxy: maskedProxyURL(refreshed.NetworkProxy),
|
||||
CooldownUntil: refreshed.CooldownUntil,
|
||||
FailureCount: refreshed.FailureCount,
|
||||
LastFailure: refreshed.LastFailure,
|
||||
@@ -506,10 +524,16 @@ func newOAuthManager(pc config.ProviderConfig, timeout time.Duration) (*oauthMan
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
client, err := newOAuthHTTPClient(resolved.Provider, timeout, "")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
bgCtx, bgCancel := context.WithCancel(context.Background())
|
||||
manager := &oauthManager{
|
||||
cfg: resolved,
|
||||
httpClient: newOAuthHTTPClient(resolved.Provider, timeout),
|
||||
timeout: timeout,
|
||||
httpClient: client,
|
||||
clients: map[string]*http.Client{"": client},
|
||||
cooldowns: map[string]time.Time{},
|
||||
bgCtx: bgCtx,
|
||||
bgCancel: bgCancel,
|
||||
@@ -518,11 +542,32 @@ func newOAuthManager(pc config.ProviderConfig, timeout time.Duration) (*oauthMan
|
||||
return manager, nil
|
||||
}
|
||||
|
||||
func newOAuthHTTPClient(provider string, timeout time.Duration) *http.Client {
|
||||
if provider == defaultClaudeOAuthProvider {
|
||||
return newAnthropicOAuthHTTPClient(timeout)
|
||||
func newOAuthHTTPClient(provider string, timeout time.Duration, proxyURL string) (*http.Client, error) {
|
||||
normalizedProxy, err := normalizeOptionalProxyURL(proxyURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &http.Client{Timeout: timeout}
|
||||
if provider == defaultClaudeOAuthProvider {
|
||||
return newAnthropicOAuthHTTPClient(timeout, normalizedProxy)
|
||||
}
|
||||
if normalizedProxy == "" {
|
||||
return &http.Client{Timeout: timeout}, nil
|
||||
}
|
||||
parsed, err := url.Parse(normalizedProxy)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &http.Client{
|
||||
Timeout: timeout,
|
||||
Transport: &http.Transport{
|
||||
Proxy: http.ProxyURL(parsed),
|
||||
ForceAttemptHTTP2: true,
|
||||
MaxIdleConns: 100,
|
||||
IdleConnTimeout: 90 * time.Second,
|
||||
TLSHandshakeTimeout: 15 * time.Second,
|
||||
ExpectContinueTimeout: time.Second,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func resolveOAuthConfig(pc config.ProviderConfig) (oauthConfig, error) {
|
||||
@@ -545,7 +590,6 @@ func resolveOAuthConfig(pc config.ProviderConfig) (oauthConfig, error) {
|
||||
TokenURL: strings.TrimSpace(pc.OAuth.TokenURL),
|
||||
RedirectURL: strings.TrimSpace(pc.OAuth.RedirectURL),
|
||||
Scopes: trimNonEmptyStrings(pc.OAuth.Scopes),
|
||||
HybridPriority: normalizeHybridPriority(pc.OAuth.HybridPriority),
|
||||
Cooldown: durationFromSeconds(pc.OAuth.CooldownSec, 15*time.Minute),
|
||||
RefreshScan: durationFromSeconds(pc.OAuth.RefreshScanSec, 10*time.Minute),
|
||||
RefreshLead: defaultRefreshLead(provider, pc.OAuth.RefreshLeadSec),
|
||||
@@ -626,9 +670,6 @@ func resolveOAuthConfig(pc config.ProviderConfig) (oauthConfig, error) {
|
||||
cfg.CredentialFiles = uniqueStrings(append([]string{cfg.CredentialFile}, cfg.CredentialFiles...))
|
||||
cfg.CredentialFile = cfg.CredentialFiles[0]
|
||||
}
|
||||
if cfg.HybridPriority == "" {
|
||||
cfg.HybridPriority = "api_first"
|
||||
}
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
@@ -675,7 +716,12 @@ func (m *oauthManager) models(ctx context.Context, apiBase string) ([]string, er
|
||||
seen := map[string]struct{}{}
|
||||
var lastErr error
|
||||
for _, attempt := range attempts {
|
||||
models, err := fetchOpenAIModels(ctx, m.httpClient, apiBase, attempt.Token)
|
||||
client, err := m.httpClientForSession(attempt.Session)
|
||||
if err != nil {
|
||||
lastErr = err
|
||||
continue
|
||||
}
|
||||
models, err := fetchOpenAIModels(ctx, client, apiBase, attempt.Token)
|
||||
if err != nil {
|
||||
lastErr = err
|
||||
continue
|
||||
@@ -704,7 +750,7 @@ func (m *oauthManager) login(ctx context.Context, apiBase string, opts OAuthLogi
|
||||
return nil, nil, fmt.Errorf("oauth manager not configured")
|
||||
}
|
||||
if m.cfg.FlowKind == oauthFlowDevice {
|
||||
flow, err := m.startDeviceFlow(ctx)
|
||||
flow, err := m.startDeviceFlow(ctx, opts)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
@@ -741,14 +787,18 @@ func (m *oauthManager) login(ctx context.Context, apiBase string, opts OAuthLogi
|
||||
}
|
||||
|
||||
func (m *oauthManager) completeLogin(ctx context.Context, apiBase, pkceVerifier string, callback *oauthCallbackResult, state string, opts OAuthLoginOptions) (*oauthSession, []string, error) {
|
||||
session, err := m.exchangeCode(ctx, callback.Code, pkceVerifier, state)
|
||||
session, err := m.exchangeCode(ctx, callback.Code, pkceVerifier, state, opts.NetworkProxy)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
if err := m.applyAccountLabel(session, opts); err != nil {
|
||||
if err := m.applySessionOptions(session, opts); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
models, _ := fetchOpenAIModels(ctx, m.httpClient, apiBase, session.AccessToken)
|
||||
client, err := m.httpClientForSession(session)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
models, _ := fetchOpenAIModels(ctx, client, apiBase, session.AccessToken)
|
||||
if len(models) > 0 {
|
||||
session.Models = append([]string(nil), models...)
|
||||
}
|
||||
@@ -788,10 +838,14 @@ func (m *oauthManager) importSession(ctx context.Context, apiBase, fileName stri
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
if err := m.applyAccountLabel(session, opts); err != nil {
|
||||
if err := m.applySessionOptions(session, opts); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
models, _ := fetchOpenAIModels(ctx, m.httpClient, apiBase, session.AccessToken)
|
||||
client, err := m.httpClientForSession(session)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
models, _ := fetchOpenAIModels(ctx, client, apiBase, session.AccessToken)
|
||||
if len(models) > 0 {
|
||||
session.Models = append([]string(nil), models...)
|
||||
}
|
||||
@@ -1094,7 +1148,7 @@ func (m *oauthManager) authorizationURL(state, pkceChallenge string) string {
|
||||
return m.cfg.AuthURL + "?" + v.Encode()
|
||||
}
|
||||
|
||||
func (m *oauthManager) exchangeCode(ctx context.Context, code, verifier, state string) (*oauthSession, error) {
|
||||
func (m *oauthManager) exchangeCode(ctx context.Context, code, verifier, state string, proxyURL string) (*oauthSession, error) {
|
||||
switch m.cfg.Provider {
|
||||
case defaultClaudeOAuthProvider:
|
||||
reqBody := map[string]any{
|
||||
@@ -1105,7 +1159,7 @@ func (m *oauthManager) exchangeCode(ctx context.Context, code, verifier, state s
|
||||
"redirect_uri": m.cfg.RedirectURL,
|
||||
"code_verifier": verifier,
|
||||
}
|
||||
raw, err := m.doJSONTokenRequest(ctx, reqBody)
|
||||
raw, err := m.doJSONTokenRequest(ctx, reqBody, proxyURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -1122,7 +1176,7 @@ func (m *oauthManager) exchangeCode(ctx context.Context, code, verifier, state s
|
||||
if m.cfg.ClientSecret != "" {
|
||||
form.Set("client_secret", m.cfg.ClientSecret)
|
||||
}
|
||||
raw, err := m.doFormTokenRequest(ctx, form)
|
||||
raw, err := m.doFormTokenRequest(ctx, form, proxyURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -1159,7 +1213,7 @@ func (m *oauthManager) refreshSessionData(ctx context.Context, session *oauthSes
|
||||
"client_id": m.cfg.ClientID,
|
||||
"grant_type": "refresh_token",
|
||||
"refresh_token": session.RefreshToken,
|
||||
})
|
||||
}, session.NetworkProxy)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -1185,7 +1239,7 @@ func (m *oauthManager) refreshSessionData(ctx context.Context, session *oauthSes
|
||||
if len(m.cfg.Scopes) > 0 && m.cfg.Provider != defaultQwenOAuthProvider && m.cfg.Provider != defaultKimiOAuthProvider {
|
||||
form.Set("scope", strings.Join(m.cfg.Scopes, " "))
|
||||
}
|
||||
raw, err := m.doFormTokenRequest(ctx, form)
|
||||
raw, err := m.doFormTokenRequest(ctx, form, session.NetworkProxy)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -1217,7 +1271,7 @@ func (m *oauthManager) refreshGoogleTokenSession(ctx context.Context, session *o
|
||||
if clientSecret != "" {
|
||||
form.Set("client_secret", clientSecret)
|
||||
}
|
||||
raw, err := m.doFormTokenRequestURL(ctx, tokenURL, form)
|
||||
raw, err := m.doFormTokenRequestURL(ctx, tokenURL, form, session.NetworkProxy)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -1254,13 +1308,13 @@ func (m *oauthManager) enrichSession(ctx context.Context, session *oauthSession)
|
||||
switch m.cfg.Provider {
|
||||
case defaultAntigravityOAuthProvider, defaultGeminiOAuthProvider:
|
||||
if strings.TrimSpace(session.Email) == "" && m.cfg.UserInfoURL != "" && session.AccessToken != "" {
|
||||
email, err := m.fetchUserEmail(ctx, session.AccessToken)
|
||||
email, err := m.fetchUserEmail(ctx, session.AccessToken, session.NetworkProxy)
|
||||
if err == nil {
|
||||
session.Email = email
|
||||
}
|
||||
}
|
||||
if m.cfg.Provider == defaultAntigravityOAuthProvider && strings.TrimSpace(session.ProjectID) == "" && session.AccessToken != "" {
|
||||
projectID, err := m.fetchAntigravityProjectID(ctx, session.AccessToken)
|
||||
projectID, err := m.fetchAntigravityProjectID(ctx, session.AccessToken, session.NetworkProxy)
|
||||
if err == nil {
|
||||
session.ProjectID = projectID
|
||||
}
|
||||
@@ -1413,6 +1467,22 @@ func (m *oauthManager) applyAccountLabel(session *oauthSession, opts OAuthLoginO
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *oauthManager) applySessionOptions(session *oauthSession, opts OAuthLoginOptions) error {
|
||||
if err := m.applyAccountLabel(session, opts); err != nil {
|
||||
return err
|
||||
}
|
||||
if session == nil {
|
||||
return fmt.Errorf("oauth session is nil")
|
||||
}
|
||||
proxyURL := firstNonEmpty(opts.NetworkProxy, session.NetworkProxy)
|
||||
normalized, err := normalizeOptionalProxyURL(proxyURL)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
session.NetworkProxy = normalized
|
||||
return nil
|
||||
}
|
||||
|
||||
func sessionLabel(session *oauthSession) string {
|
||||
if session == nil {
|
||||
return ""
|
||||
@@ -1420,6 +1490,34 @@ func sessionLabel(session *oauthSession) string {
|
||||
return firstNonEmpty(session.Email, session.AccountID, session.ProjectID)
|
||||
}
|
||||
|
||||
func (m *oauthManager) httpClientForSession(session *oauthSession) (*http.Client, error) {
|
||||
if session == nil {
|
||||
return m.httpClient, nil
|
||||
}
|
||||
return m.httpClientForProxy(session.NetworkProxy)
|
||||
}
|
||||
|
||||
func (m *oauthManager) httpClientForProxy(proxyURL string) (*http.Client, error) {
|
||||
normalized, err := normalizeOptionalProxyURL(proxyURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if normalized == "" {
|
||||
return m.httpClient, nil
|
||||
}
|
||||
m.clientMu.Lock()
|
||||
defer m.clientMu.Unlock()
|
||||
if client, ok := m.clients[normalized]; ok && client != nil {
|
||||
return client, nil
|
||||
}
|
||||
client, err := newOAuthHTTPClient(m.cfg.Provider, m.timeout, normalized)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
m.clients[normalized] = client
|
||||
return client, nil
|
||||
}
|
||||
|
||||
func (m *oauthManager) allocateCredentialPathLocked(session *oauthSession) (string, error) {
|
||||
files := m.credentialFiles()
|
||||
if len(files) == 1 {
|
||||
@@ -1517,6 +1615,7 @@ func mergeOAuthSession(prev, next *oauthSession) *oauthSession {
|
||||
merged.ProjectID = firstNonEmpty(next.ProjectID, prev.ProjectID)
|
||||
merged.DeviceID = firstNonEmpty(next.DeviceID, prev.DeviceID)
|
||||
merged.ResourceURL = firstNonEmpty(next.ResourceURL, prev.ResourceURL)
|
||||
merged.NetworkProxy = firstNonEmpty(next.NetworkProxy, prev.NetworkProxy)
|
||||
merged.Scope = firstNonEmpty(next.Scope, prev.Scope)
|
||||
merged.Models = append([]string(nil), prev.Models...)
|
||||
if len(next.Models) > 0 {
|
||||
@@ -1828,17 +1927,6 @@ func durationFromSeconds(value int, fallback time.Duration) time.Duration {
|
||||
return time.Duration(value) * time.Second
|
||||
}
|
||||
|
||||
func normalizeHybridPriority(value string) string {
|
||||
switch strings.ToLower(strings.TrimSpace(value)) {
|
||||
case "", "api_first":
|
||||
return "api_first"
|
||||
case "oauth_first":
|
||||
return "oauth_first"
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
func parseOAuthCallbackURL(raw string) (*oauthCallbackResult, error) {
|
||||
parsed, err := url.Parse(strings.TrimSpace(raw))
|
||||
if err != nil {
|
||||
@@ -1902,6 +1990,7 @@ func parseImportedOAuthSession(provider, fileName string, data []byte) (*oauthSe
|
||||
session.ProjectID = firstNonEmpty(asString(raw["project_id"]), asString(raw["projectId"]))
|
||||
session.DeviceID = firstNonEmpty(asString(raw["device_id"]), asString(raw["deviceId"]))
|
||||
session.ResourceURL = asString(raw["resource_url"])
|
||||
session.NetworkProxy = firstNonEmpty(asString(raw["network_proxy"]), asString(raw["proxy_url"]), asString(raw["http_proxy"]))
|
||||
session.Scope = firstNonEmpty(asString(raw["scope"]), asString(raw["scopes"]))
|
||||
if models := stringSliceFromAny(raw["models"]); len(models) > 0 {
|
||||
session.Models = models
|
||||
@@ -1916,6 +2005,7 @@ func parseImportedOAuthSession(provider, fileName string, data []byte) (*oauthSe
|
||||
)
|
||||
session.ProjectID = firstNonEmpty(session.ProjectID, asString(session.Token["project_id"]), asString(session.Token["projectId"]))
|
||||
session.DeviceID = firstNonEmpty(session.DeviceID, asString(session.Token["device_id"]), asString(session.Token["deviceId"]))
|
||||
session.NetworkProxy = firstNonEmpty(session.NetworkProxy, asString(session.Token["network_proxy"]), asString(session.Token["proxy_url"]), asString(session.Token["http_proxy"]))
|
||||
session.Scope = firstNonEmpty(session.Scope, asString(session.Token["scope"]), asString(session.Token["scopes"]))
|
||||
}
|
||||
if claims := parseJWTClaims(session.IDToken); len(claims) > 0 {
|
||||
@@ -2043,21 +2133,21 @@ func defaultInt(value, fallback int) int {
|
||||
return fallback
|
||||
}
|
||||
|
||||
func (m *oauthManager) doFormTokenRequest(ctx context.Context, form url.Values) (map[string]any, error) {
|
||||
return m.doFormTokenRequestURL(ctx, m.cfg.TokenURL, form)
|
||||
func (m *oauthManager) doFormTokenRequest(ctx context.Context, form url.Values, proxyURL string) (map[string]any, error) {
|
||||
return m.doFormTokenRequestURL(ctx, m.cfg.TokenURL, form, proxyURL)
|
||||
}
|
||||
|
||||
func (m *oauthManager) doFormTokenRequestURL(ctx context.Context, endpoint string, form url.Values) (map[string]any, error) {
|
||||
func (m *oauthManager) doFormTokenRequestURL(ctx context.Context, endpoint string, form url.Values, proxyURL string) (map[string]any, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, strings.NewReader(form.Encode()))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
return m.doJSONRequest(req, "oauth token request")
|
||||
return m.doJSONRequest(req, "oauth token request", proxyURL)
|
||||
}
|
||||
|
||||
func (m *oauthManager) doJSONTokenRequest(ctx context.Context, payload map[string]any) (map[string]any, error) {
|
||||
func (m *oauthManager) doJSONTokenRequest(ctx context.Context, payload map[string]any, proxyURL string) (map[string]any, error) {
|
||||
body, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -2068,11 +2158,15 @@ func (m *oauthManager) doJSONTokenRequest(ctx context.Context, payload map[strin
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
return m.doJSONRequest(req, "oauth token request")
|
||||
return m.doJSONRequest(req, "oauth token request", proxyURL)
|
||||
}
|
||||
|
||||
func (m *oauthManager) doJSONRequest(req *http.Request, label string) (map[string]any, error) {
|
||||
resp, err := m.httpClient.Do(req)
|
||||
func (m *oauthManager) doJSONRequest(req *http.Request, label, proxyURL string) (map[string]any, error) {
|
||||
client, err := m.httpClientForProxy(proxyURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%s failed: %w", label, err)
|
||||
}
|
||||
@@ -2091,13 +2185,13 @@ func (m *oauthManager) doJSONRequest(req *http.Request, label string) (map[strin
|
||||
return raw, nil
|
||||
}
|
||||
|
||||
func (m *oauthManager) fetchUserEmail(ctx context.Context, token string) (string, error) {
|
||||
func (m *oauthManager) fetchUserEmail(ctx context.Context, token, proxyURL string) (string, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, m.cfg.UserInfoURL, nil)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+strings.TrimSpace(token))
|
||||
raw, err := m.doJSONRequest(req, "oauth userinfo request")
|
||||
raw, err := m.doJSONRequest(req, "oauth userinfo request", proxyURL)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
@@ -2108,7 +2202,7 @@ func (m *oauthManager) fetchUserEmail(ctx context.Context, token string) (string
|
||||
return email, nil
|
||||
}
|
||||
|
||||
func (m *oauthManager) fetchAntigravityProjectID(ctx context.Context, token string) (string, error) {
|
||||
func (m *oauthManager) fetchAntigravityProjectID(ctx context.Context, token, proxyURL string) (string, error) {
|
||||
endpointURL := fmt.Sprintf("%s/%s:loadCodeAssist", defaultAntigravityAPIEndpoint, defaultAntigravityAPIVersion)
|
||||
body := `{"metadata":{"ideType":"ANTIGRAVITY","platform":"PLATFORM_UNSPECIFIED","pluginType":"GEMINI"}}`
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpointURL, strings.NewReader(body))
|
||||
@@ -2120,7 +2214,11 @@ func (m *oauthManager) fetchAntigravityProjectID(ctx context.Context, token stri
|
||||
req.Header.Set("User-Agent", defaultAntigravityAPIUserAgent)
|
||||
req.Header.Set("X-Goog-Api-Client", defaultAntigravityAPIClient)
|
||||
req.Header.Set("Client-Metadata", defaultAntigravityClientMeta)
|
||||
resp, err := m.httpClient.Do(req)
|
||||
client, err := m.httpClientForProxy(proxyURL)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
@@ -2148,7 +2246,7 @@ func (m *oauthManager) fetchAntigravityProjectID(ctx context.Context, token stri
|
||||
return projectID, nil
|
||||
}
|
||||
|
||||
func (m *oauthManager) startDeviceFlow(ctx context.Context) (*OAuthPendingFlow, error) {
|
||||
func (m *oauthManager) startDeviceFlow(ctx context.Context, opts OAuthLoginOptions) (*OAuthPendingFlow, error) {
|
||||
if m.cfg.FlowKind != oauthFlowDevice {
|
||||
return nil, fmt.Errorf("oauth provider %s does not use device flow", m.cfg.Provider)
|
||||
}
|
||||
@@ -2165,7 +2263,7 @@ func (m *oauthManager) startDeviceFlow(ctx context.Context) (*OAuthPendingFlow,
|
||||
}
|
||||
form.Set("code_challenge", challenge)
|
||||
form.Set("code_challenge_method", "S256")
|
||||
raw, err := m.doFormDeviceRequest(ctx, m.cfg.DeviceCodeURL, form)
|
||||
raw, err := m.doFormDeviceRequest(ctx, m.cfg.DeviceCodeURL, form, opts.NetworkProxy)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -2184,7 +2282,7 @@ func (m *oauthManager) startDeviceFlow(ctx context.Context) (*OAuthPendingFlow,
|
||||
Instructions: "Open the verification URL, finish authorization, then click continue to let the gateway poll for tokens.",
|
||||
}, nil
|
||||
case defaultKimiOAuthProvider:
|
||||
raw, err := m.doFormDeviceRequest(ctx, m.cfg.DeviceCodeURL, form)
|
||||
raw, err := m.doFormDeviceRequest(ctx, m.cfg.DeviceCodeURL, form, opts.NetworkProxy)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -2206,7 +2304,7 @@ func (m *oauthManager) startDeviceFlow(ctx context.Context) (*OAuthPendingFlow,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *oauthManager) doFormDeviceRequest(ctx context.Context, endpoint string, form url.Values) (map[string]any, error) {
|
||||
func (m *oauthManager) doFormDeviceRequest(ctx context.Context, endpoint string, form url.Values, proxyURL string) (map[string]any, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, strings.NewReader(form.Encode()))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -2220,7 +2318,7 @@ func (m *oauthManager) doFormDeviceRequest(ctx context.Context, endpoint string,
|
||||
req.Header.Set("X-Msh-Device-Model", runtime.GOOS+" "+runtime.GOARCH)
|
||||
req.Header.Set("X-Msh-Device-Id", randomDeviceID())
|
||||
}
|
||||
return m.doJSONRequest(req, "oauth device request")
|
||||
return m.doJSONRequest(req, "oauth device request", proxyURL)
|
||||
}
|
||||
|
||||
func parseDeviceFlowPayload(raw map[string]any) (*oauthDeviceCodeResponse, error) {
|
||||
@@ -2254,14 +2352,18 @@ func randomDeviceID() string {
|
||||
}
|
||||
|
||||
func (m *oauthManager) completeDeviceFlow(ctx context.Context, apiBase string, flow *OAuthPendingFlow, opts OAuthLoginOptions) (*oauthSession, []string, error) {
|
||||
session, err := m.pollDeviceToken(ctx, flow)
|
||||
session, err := m.pollDeviceToken(ctx, flow, opts.NetworkProxy)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
if err := m.applyAccountLabel(session, opts); err != nil {
|
||||
if err := m.applySessionOptions(session, opts); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
models, _ := fetchOpenAIModels(ctx, m.httpClient, apiBase, session.AccessToken)
|
||||
client, err := m.httpClientForSession(session)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
models, _ := fetchOpenAIModels(ctx, client, apiBase, session.AccessToken)
|
||||
if len(models) > 0 {
|
||||
session.Models = append([]string(nil), models...)
|
||||
}
|
||||
@@ -2281,7 +2383,7 @@ func (m *oauthManager) completeDeviceFlow(ctx context.Context, apiBase string, f
|
||||
return session, models, nil
|
||||
}
|
||||
|
||||
func (m *oauthManager) pollDeviceToken(ctx context.Context, flow *OAuthPendingFlow) (*oauthSession, error) {
|
||||
func (m *oauthManager) pollDeviceToken(ctx context.Context, flow *OAuthPendingFlow, proxyURL string) (*oauthSession, error) {
|
||||
if flow == nil || strings.TrimSpace(flow.DeviceCode) == "" {
|
||||
return nil, fmt.Errorf("oauth device flow missing device code")
|
||||
}
|
||||
@@ -2306,7 +2408,7 @@ func (m *oauthManager) pollDeviceToken(ctx context.Context, flow *OAuthPendingFl
|
||||
if flow.PKCEVerifier != "" {
|
||||
form.Set("code_verifier", flow.PKCEVerifier)
|
||||
}
|
||||
raw, err := m.doFormTokenRequest(ctx, form)
|
||||
raw, err := m.doFormTokenRequest(ctx, form, proxyURL)
|
||||
if err == nil {
|
||||
session, convErr := sessionFromTokenPayload(m.cfg.Provider, raw)
|
||||
if convErr != nil {
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
@@ -386,6 +387,118 @@ func TestResolveOAuthConfigAppliesProviderRefreshLeadDefaults(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestHTTPProviderOAuthSessionProxyRoutesRefreshAndResponses(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var refreshCalls int32
|
||||
var responseCalls int32
|
||||
target := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/oauth/token":
|
||||
atomic.AddInt32(&refreshCalls, 1)
|
||||
if err := r.ParseForm(); err != nil {
|
||||
t.Fatalf("parse token form failed: %v", err)
|
||||
}
|
||||
if got := r.Form.Get("grant_type"); got != "refresh_token" {
|
||||
t.Fatalf("unexpected grant_type: %s", got)
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"access_token":"proxied-fresh-token","refresh_token":"refresh-token","expires_in":3600}`))
|
||||
case "/v1/responses":
|
||||
atomic.AddInt32(&responseCalls, 1)
|
||||
if got := r.Header.Get("Authorization"); got != "Bearer proxied-fresh-token" {
|
||||
t.Fatalf("unexpected authorization header: %s", got)
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"status":"completed","output_text":"ok-via-proxy"}`))
|
||||
default:
|
||||
http.NotFound(w, r)
|
||||
}
|
||||
}))
|
||||
defer target.Close()
|
||||
|
||||
var proxyCalls int32
|
||||
proxyServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
atomic.AddInt32(&proxyCalls, 1)
|
||||
targetURL := r.URL.String()
|
||||
if !strings.HasPrefix(targetURL, "http://") && !strings.HasPrefix(targetURL, "https://") {
|
||||
targetURL = target.URL + r.URL.Path
|
||||
if rawQuery := strings.TrimSpace(r.URL.RawQuery); rawQuery != "" {
|
||||
targetURL += "?" + rawQuery
|
||||
}
|
||||
}
|
||||
req, err := http.NewRequestWithContext(r.Context(), r.Method, targetURL, r.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("create proxied request failed: %v", err)
|
||||
}
|
||||
req.Header = r.Header.Clone()
|
||||
resp, err := http.DefaultTransport.RoundTrip(req)
|
||||
if err != nil {
|
||||
t.Fatalf("proxy round trip failed: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
for key, values := range resp.Header {
|
||||
for _, value := range values {
|
||||
w.Header().Add(key, value)
|
||||
}
|
||||
}
|
||||
w.WriteHeader(resp.StatusCode)
|
||||
_, _ = io.Copy(w, resp.Body)
|
||||
}))
|
||||
defer proxyServer.Close()
|
||||
|
||||
credFile := filepath.Join(t.TempDir(), "proxied.json")
|
||||
raw, err := json.Marshal(oauthSession{
|
||||
Provider: "codex",
|
||||
AccessToken: "expired-token",
|
||||
RefreshToken: "refresh-token",
|
||||
Expire: time.Now().Add(-time.Hour).Format(time.RFC3339),
|
||||
NetworkProxy: proxyServer.URL,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("marshal session failed: %v", err)
|
||||
}
|
||||
if err := os.WriteFile(credFile, raw, 0o600); err != nil {
|
||||
t.Fatalf("write credential file failed: %v", err)
|
||||
}
|
||||
|
||||
pc := config.ProviderConfig{
|
||||
APIBase: target.URL + "/v1",
|
||||
Auth: "oauth",
|
||||
TimeoutSec: 5,
|
||||
OAuth: config.ProviderOAuthConfig{
|
||||
Provider: "codex",
|
||||
CredentialFile: credFile,
|
||||
ClientID: "test-client",
|
||||
TokenURL: target.URL + "/oauth/token",
|
||||
AuthURL: target.URL + "/oauth/authorize",
|
||||
},
|
||||
}
|
||||
oauth, err := newOAuthManager(pc, 5*time.Second)
|
||||
if err != nil {
|
||||
t.Fatalf("new oauth manager failed: %v", err)
|
||||
}
|
||||
defer oauth.bgCancel()
|
||||
|
||||
provider := NewHTTPProvider("test-oauth", "", pc.APIBase, "gpt-test", false, "oauth", 5*time.Second, oauth)
|
||||
resp, err := provider.Chat(context.Background(), []Message{{Role: "user", Content: "hello"}}, nil, "gpt-test", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("chat failed: %v", err)
|
||||
}
|
||||
if resp.Content != "ok-via-proxy" {
|
||||
t.Fatalf("unexpected response content: %q", resp.Content)
|
||||
}
|
||||
if atomic.LoadInt32(&refreshCalls) != 1 {
|
||||
t.Fatalf("expected one refresh call, got %d", refreshCalls)
|
||||
}
|
||||
if atomic.LoadInt32(&responseCalls) != 1 {
|
||||
t.Fatalf("expected one response call, got %d", responseCalls)
|
||||
}
|
||||
if got := atomic.LoadInt32(&proxyCalls); got < 2 {
|
||||
t.Fatalf("expected proxy to receive refresh and response requests, got %d", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOAuthImportGeminiNestedTokenRefreshesWithTokenMetadata(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -547,7 +660,7 @@ func TestQwenDeviceFlowRequiresAccountLabelWhenEmailMissing(t *testing.T) {
|
||||
}
|
||||
defer manager.bgCancel()
|
||||
|
||||
flow, err := manager.startDeviceFlow(context.Background())
|
||||
flow, err := manager.startDeviceFlow(context.Background(), OAuthLoginOptions{})
|
||||
if err != nil {
|
||||
t.Fatalf("start device flow failed: %v", err)
|
||||
}
|
||||
@@ -734,7 +847,7 @@ func TestOAuthDeviceFlowQwenManualCompletes(t *testing.T) {
|
||||
t.Fatalf("new oauth manager failed: %v", err)
|
||||
}
|
||||
|
||||
flow, err := manager.startDeviceFlow(context.Background())
|
||||
flow, err := manager.startDeviceFlow(context.Background(), OAuthLoginOptions{})
|
||||
if err != nil {
|
||||
t.Fatalf("start device flow failed: %v", err)
|
||||
}
|
||||
@@ -879,7 +992,6 @@ func TestHTTPProviderHybridOAuthFirstUsesOAuthBeforeAPIKey(t *testing.T) {
|
||||
CredentialFile: credFile,
|
||||
TokenURL: server.URL + "/oauth/token",
|
||||
AuthURL: server.URL + "/oauth/authorize",
|
||||
HybridPriority: "oauth_first",
|
||||
},
|
||||
}
|
||||
oauth, err := newOAuthManager(pc, 5*time.Second)
|
||||
@@ -891,11 +1003,11 @@ func TestHTTPProviderHybridOAuthFirstUsesOAuthBeforeAPIKey(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatalf("chat failed: %v", err)
|
||||
}
|
||||
if resp.Content != "ok-from-oauth" {
|
||||
if resp.Content != "ok-from-api" {
|
||||
t.Fatalf("unexpected response content: %q", resp.Content)
|
||||
}
|
||||
if atomic.LoadInt32(&oauthCalls) != 1 || atomic.LoadInt32(&apiKeyCalls) != 0 {
|
||||
t.Fatalf("expected oauth first only, got api=%d oauth=%d", apiKeyCalls, oauthCalls)
|
||||
if atomic.LoadInt32(&apiKeyCalls) != 1 || atomic.LoadInt32(&oauthCalls) != 0 {
|
||||
t.Fatalf("expected api key first only, got api=%d oauth=%d", apiKeyCalls, oauthCalls)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1187,7 +1299,6 @@ func TestProviderRuntimeSnapshotIncludesCandidateOrderAndLastSuccess(t *testing.
|
||||
OAuth: config.ProviderOAuthConfig{
|
||||
Provider: "codex",
|
||||
CredentialFile: credFile,
|
||||
HybridPriority: "api_first",
|
||||
},
|
||||
}
|
||||
ConfigureProviderRuntime(name, pc)
|
||||
@@ -1206,8 +1317,8 @@ func TestProviderRuntimeSnapshotIncludesCandidateOrderAndLastSuccess(t *testing.
|
||||
provider.markAttemptSuccess(attempts[1])
|
||||
|
||||
cfg := &config.Config{
|
||||
Providers: config.ProvidersConfig{
|
||||
Proxies: map[string]config.ProviderConfig{name: pc},
|
||||
Models: config.ModelsConfig{
|
||||
Providers: map[string]config.ProviderConfig{name: pc},
|
||||
},
|
||||
}
|
||||
snapshot := GetProviderRuntimeSnapshot(cfg)
|
||||
@@ -1267,8 +1378,8 @@ func TestConfigureProviderRuntimeLoadsPersistedEvents(t *testing.T) {
|
||||
})
|
||||
|
||||
cfg := &config.Config{
|
||||
Providers: config.ProvidersConfig{
|
||||
Proxies: map[string]config.ProviderConfig{
|
||||
Models: config.ModelsConfig{
|
||||
Providers: map[string]config.ProviderConfig{
|
||||
name: {
|
||||
APIBase: "https://example.com/v1",
|
||||
Auth: "bearer",
|
||||
@@ -1357,7 +1468,6 @@ func TestUpdateCandidateOrderRecordsSchedulerChange(t *testing.T) {
|
||||
OAuth: config.ProviderOAuthConfig{
|
||||
Provider: "codex",
|
||||
CredentialFile: credFile,
|
||||
HybridPriority: "api_first",
|
||||
},
|
||||
}
|
||||
manager, err := newOAuthManager(pc, 5*time.Second)
|
||||
@@ -1412,8 +1522,8 @@ func TestGetProviderRuntimeViewFiltersEvents(t *testing.T) {
|
||||
providerRuntimeRegistry.mu.Unlock()
|
||||
|
||||
cfg := &config.Config{
|
||||
Providers: config.ProvidersConfig{
|
||||
Proxies: map[string]config.ProviderConfig{
|
||||
Models: config.ModelsConfig{
|
||||
Providers: map[string]config.ProviderConfig{
|
||||
name: {APIBase: "https://example.com/v1", Auth: "hybrid", APIKey: "api-key"},
|
||||
},
|
||||
},
|
||||
@@ -1463,8 +1573,8 @@ func TestGetProviderRuntimeViewCursorPagination(t *testing.T) {
|
||||
}
|
||||
providerRuntimeRegistry.mu.Unlock()
|
||||
cfg := &config.Config{
|
||||
Providers: config.ProvidersConfig{
|
||||
Proxies: map[string]config.ProviderConfig{
|
||||
Models: config.ModelsConfig{
|
||||
Providers: map[string]config.ProviderConfig{
|
||||
name: {APIBase: "https://example.com/v1", Auth: "oauth", OAuth: config.ProviderOAuthConfig{Provider: "codex"}},
|
||||
},
|
||||
},
|
||||
@@ -1503,8 +1613,8 @@ func TestGetProviderRuntimeViewSortAscending(t *testing.T) {
|
||||
}
|
||||
providerRuntimeRegistry.mu.Unlock()
|
||||
cfg := &config.Config{
|
||||
Providers: config.ProvidersConfig{
|
||||
Proxies: map[string]config.ProviderConfig{
|
||||
Models: config.ModelsConfig{
|
||||
Providers: map[string]config.ProviderConfig{
|
||||
name: {APIBase: "https://example.com/v1", Auth: "oauth", OAuth: config.ProviderOAuthConfig{Provider: "codex"}},
|
||||
},
|
||||
},
|
||||
@@ -1542,8 +1652,8 @@ func TestGetProviderRuntimeViewFiltersByHealthAndCooldown(t *testing.T) {
|
||||
}
|
||||
providerRuntimeRegistry.mu.Unlock()
|
||||
cfg := &config.Config{
|
||||
Providers: config.ProvidersConfig{
|
||||
Proxies: map[string]config.ProviderConfig{
|
||||
Models: config.ModelsConfig{
|
||||
Providers: map[string]config.ProviderConfig{
|
||||
name: {APIBase: "https://example.com/v1", Auth: "bearer", APIKey: "api-key"},
|
||||
},
|
||||
},
|
||||
@@ -1594,8 +1704,8 @@ func TestGetProviderRuntimeSummaryFlagsUnhealthyProviders(t *testing.T) {
|
||||
}
|
||||
providerRuntimeRegistry.mu.Unlock()
|
||||
cfg := &config.Config{
|
||||
Providers: config.ProvidersConfig{
|
||||
Proxies: map[string]config.ProviderConfig{
|
||||
Models: config.ModelsConfig{
|
||||
Providers: map[string]config.ProviderConfig{
|
||||
name: {APIBase: "https://example.com/v1", Auth: "bearer", APIKey: "api-key"},
|
||||
},
|
||||
},
|
||||
@@ -1643,8 +1753,8 @@ func TestGetProviderRuntimeSummaryMarksRecentErrorsAsDegraded(t *testing.T) {
|
||||
}
|
||||
providerRuntimeRegistry.mu.Unlock()
|
||||
cfg := &config.Config{
|
||||
Providers: config.ProvidersConfig{
|
||||
Proxies: map[string]config.ProviderConfig{
|
||||
Models: config.ModelsConfig{
|
||||
Providers: map[string]config.ProviderConfig{
|
||||
name: {APIBase: "https://example.com/v1", Auth: "oauth", OAuth: config.ProviderOAuthConfig{Provider: "codex"}},
|
||||
},
|
||||
},
|
||||
@@ -1688,8 +1798,8 @@ func TestGetProviderRuntimeSummaryIncludesOAuthAccountMetadata(t *testing.T) {
|
||||
t.Fatalf("write session failed: %v", err)
|
||||
}
|
||||
cfg := &config.Config{
|
||||
Providers: config.ProvidersConfig{
|
||||
Proxies: map[string]config.ProviderConfig{
|
||||
Models: config.ModelsConfig{
|
||||
Providers: map[string]config.ProviderConfig{
|
||||
"qwen-summary": {
|
||||
APIBase: "https://example.com/v1",
|
||||
Auth: "oauth",
|
||||
@@ -1746,8 +1856,8 @@ func TestRefreshProviderRuntimeNowSupportsOnlyExpiring(t *testing.T) {
|
||||
|
||||
name := "runtime-refresh-provider"
|
||||
cfg := &config.Config{
|
||||
Providers: config.ProvidersConfig{
|
||||
Proxies: map[string]config.ProviderConfig{
|
||||
Models: config.ModelsConfig{
|
||||
Providers: map[string]config.ProviderConfig{
|
||||
name: {
|
||||
APIBase: server.URL + "/v1",
|
||||
Auth: "oauth",
|
||||
@@ -1807,8 +1917,8 @@ func TestRerankProviderRuntimeUpdatesCandidateOrder(t *testing.T) {
|
||||
}
|
||||
name := "rerank-runtime-provider"
|
||||
cfg := &config.Config{
|
||||
Providers: config.ProvidersConfig{
|
||||
Proxies: map[string]config.ProviderConfig{
|
||||
Models: config.ModelsConfig{
|
||||
Providers: map[string]config.ProviderConfig{
|
||||
name: {
|
||||
APIKey: "api-key",
|
||||
APIBase: "https://example.com/v1",
|
||||
@@ -1817,7 +1927,6 @@ func TestRerankProviderRuntimeUpdatesCandidateOrder(t *testing.T) {
|
||||
OAuth: config.ProviderOAuthConfig{
|
||||
Provider: "codex",
|
||||
CredentialFile: credFile,
|
||||
HybridPriority: "oauth_first",
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -1827,8 +1936,8 @@ func TestRerankProviderRuntimeUpdatesCandidateOrder(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatalf("rerank provider runtime failed: %v", err)
|
||||
}
|
||||
if len(order) == 0 || order[0].Kind != "oauth" {
|
||||
t.Fatalf("expected oauth-first rerank result, got %#v", order)
|
||||
if len(order) == 0 || order[0].Kind != "api_key" {
|
||||
t.Fatalf("expected api-key-first rerank result, got %#v", order)
|
||||
}
|
||||
snapshot := GetProviderRuntimeSnapshot(cfg)
|
||||
items, _ := snapshot["items"].([]map[string]interface{})
|
||||
@@ -1836,7 +1945,7 @@ func TestRerankProviderRuntimeUpdatesCandidateOrder(t *testing.T) {
|
||||
t.Fatalf("expected one runtime item, got %#v", snapshot)
|
||||
}
|
||||
snapshotOrder, _ := items[0]["candidate_order"].([]providerRuntimeCandidate)
|
||||
if len(snapshotOrder) == 0 || snapshotOrder[0].Kind != "oauth" {
|
||||
t.Fatalf("expected oauth-first candidate order, got %#v", items[0]["candidate_order"])
|
||||
if len(snapshotOrder) == 0 || snapshotOrder[0].Kind != "api_key" {
|
||||
t.Fatalf("expected api-key-first candidate order, got %#v", items[0]["candidate_order"])
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user