From 1c0e463d075914059f2f2f9d4de34c40e8cf3d4d Mon Sep 17 00:00:00 2001 From: lpf Date: Wed, 11 Mar 2026 15:47:49 +0800 Subject: [PATCH] Add OAuth provider runtime and providers UI --- Makefile | 4 +- README.md | 24 + README_EN.md | 24 + cmd/cmd_config.go | 166 +- cmd/cmd_gateway.go | 30 +- cmd/sentinel_notify.go | 84 + cmd/sentinel_notify_test.go | 85 + config.example.json | 112 + go.mod | 1 + go.sum | 2 + pkg/agent/loop.go | 18 + pkg/agent/loop_audit_test.go | 48 + pkg/agent/session_planner.go | 305 ++- pkg/agent/session_planner_test.go | 166 ++ pkg/api/server.go | 513 ++++ pkg/config/config.go | 23 + pkg/config/validate.go | 42 +- pkg/config/validate_test.go | 132 +- pkg/providers/anthropic_transport.go | 107 + pkg/providers/http_provider.go | 1289 ++++++++- pkg/providers/oauth.go | 2334 +++++++++++++++++ pkg/providers/oauth_test.go | 1842 +++++++++++++ webui/src/App.tsx | 2 + webui/src/components/FormControls.tsx | 115 + webui/src/components/GlobalDialog.tsx | 6 +- webui/src/components/RecursiveConfig.tsx | 27 +- webui/src/components/Sidebar.tsx | 11 +- .../components/config/ConfigPageChrome.tsx | 148 ++ .../config/GatewayConfigSection.tsx | 276 ++ .../config/ProviderConfigSection.tsx | 558 ++++ webui/src/components/config/configUtils.ts | 115 + .../config/useConfigGatewayActions.ts | 66 + .../components/config/useConfigNavigation.ts | 64 + .../config/useConfigProviderActions.ts | 340 +++ .../config/useConfigRuntimeView.tsx | 47 + .../components/config/useConfigSaveAction.ts | 79 + webui/src/context/AppContext.tsx | 10 + webui/src/i18n/index.ts | 146 +- webui/src/pages/ChannelSettings.tsx | 60 +- webui/src/pages/Chat.tsx | 24 +- webui/src/pages/Config.tsx | 628 +---- webui/src/pages/Cron.tsx | 57 +- webui/src/pages/EKG.tsx | 6 +- webui/src/pages/LogCodes.tsx | 6 +- webui/src/pages/MCP.tsx | 61 +- webui/src/pages/Memory.tsx | 4 +- webui/src/pages/NodeArtifacts.tsx | 19 +- webui/src/pages/Nodes.tsx | 29 +- webui/src/pages/Providers.tsx | 263 ++ webui/src/pages/Skills.tsx | 15 +- webui/src/pages/SubagentProfiles.tsx | 130 +- webui/src/pages/TaskAudit.tsx | 10 +- 52 files changed, 9772 insertions(+), 901 deletions(-) create mode 100644 cmd/sentinel_notify.go create mode 100644 cmd/sentinel_notify_test.go create mode 100644 pkg/agent/loop_audit_test.go create mode 100644 pkg/providers/anthropic_transport.go create mode 100644 pkg/providers/oauth.go create mode 100644 pkg/providers/oauth_test.go create mode 100644 webui/src/components/FormControls.tsx create mode 100644 webui/src/components/config/ConfigPageChrome.tsx create mode 100644 webui/src/components/config/GatewayConfigSection.tsx create mode 100644 webui/src/components/config/ProviderConfigSection.tsx create mode 100644 webui/src/components/config/configUtils.ts create mode 100644 webui/src/components/config/useConfigGatewayActions.ts create mode 100644 webui/src/components/config/useConfigNavigation.ts create mode 100644 webui/src/components/config/useConfigProviderActions.ts create mode 100644 webui/src/components/config/useConfigRuntimeView.tsx create mode 100644 webui/src/components/config/useConfigSaveAction.ts create mode 100644 webui/src/pages/Providers.tsx diff --git a/Makefile b/Makefile index b9d19b4..2410736 100644 --- a/Makefile +++ b/Makefile @@ -438,8 +438,8 @@ deps: run: build @$(BUILD_DIR)/$(BINARY_NAME) $(ARGS) -## dev: Run the local gateway in foreground for debugging -dev: sync-embed-workspace +## dev: Build WebUI, sync workspace, and run the local gateway in foreground for debugging +dev: build-webui sync-embed-workspace @if [ ! -f "$(DEV_CONFIG)" ]; then \ echo "✗ Missing config file: $(DEV_CONFIG)"; \ echo " Override with: make dev DEV_CONFIG=/path/to/config.json"; \ diff --git a/README.md b/README.md index d2558ff..d6211f7 100644 --- a/README.md +++ b/README.md @@ -108,6 +108,30 @@ clawgo onboard clawgo provider ``` +如果服务商使用 OAuth 登录,例如 `Codex`、`Anthropic`、`Antigravity`、`Gemini CLI`、`Kimi`、`Qwen`: + +```bash +clawgo provider +clawgo provider login codex-oauth +clawgo provider login codex-oauth --manual +``` + +登录完成后会把 OAuth 凭证保存到本地,并自动同步该账号可用模型,后续可直接作为普通 provider 使用。 +回调型 OAuth(如 `codex` / `anthropic` / `antigravity` / `gemini`)在云服务器场景下可使用 `--manual`:服务端打印授权链接,你在桌面浏览器登录后,把最终回调 URL 粘贴回终端即可完成换取 token。 +设备码型 OAuth(如 `kimi` / `qwen`)会直接打印验证链接和用户码,桌面浏览器完成授权后,网关会自动轮询换取 token,无需回填 callback URL。 +对同一个 provider 重复执行 `clawgo provider login codex-oauth --manual` 会追加多个 OAuth 账号;当某个账号额度耗尽或触发限流时,会自动切换到下一个已登录账号重试。 +WebUI 也支持发起 OAuth 登录、回填 callback URL、设备码确认、上传 `auth.json`、查看账号列表、手动刷新和删除账号。 + +如果你同时有 `API key` 和 `OAuth` 账号,推荐直接把同一个 provider 配成 `auth: "hybrid"`: + +- 优先使用 `api_key` +- 当 `api_key` 触发额度不足、429、限流等错误时,自动切到该 provider 下的 OAuth 账号池 +- OAuth 账号仍然支持多账号轮换、后台预刷新、`auth.json` 导入和 WebUI 管理 +- `oauth.hybrid_priority` 可选 `api_first` / `oauth_first` +- `oauth.cooldown_sec` 可控制某个 OAuth 账号被限流后暂时熔断多久,默认 `900` +- provider runtime 面板会显示当前候选池排序、最近一次成功命中的凭证,以及最近命中/错误历史 +- 如需在重启后保留 runtime 历史,可给 provider 配置 `runtime_persist`、`runtime_history_file`、`runtime_history_max` + ### 4. 启动 交互模式: diff --git a/README_EN.md b/README_EN.md index 58a80e4..e7562b7 100644 --- a/README_EN.md +++ b/README_EN.md @@ -99,6 +99,30 @@ clawgo onboard clawgo provider ``` +For OAuth-backed providers such as `Codex`, `Anthropic`, `Antigravity`, `Gemini CLI`, `Kimi`, and `Qwen`: + +```bash +clawgo provider +clawgo provider login codex-oauth +clawgo provider login codex-oauth --manual +``` + +After login, clawgo stores the OAuth session locally and syncs the models available to that account so the provider can be used directly. +Use `--manual` on a cloud server for callback-based OAuth (`codex`, `anthropic`, `antigravity`, `gemini`): clawgo prints the auth URL, you complete login in a desktop browser, then paste the final callback URL back into the terminal. +Device-flow OAuth (`kimi`, `qwen`) prints the verification URL and user code, then clawgo polls automatically after authorization without requiring a callback URL to be pasted back. +Repeat `clawgo provider login codex-oauth --manual` on the same provider to add multiple OAuth accounts; when one account hits quota or rate limits, clawgo automatically retries with the next logged-in account. +The WebUI can also start OAuth login, accept callback URL pasteback, confirm device-flow authorization, import `auth.json`, list accounts, refresh accounts, and delete accounts. + +If you have both an `API key` and OAuth accounts for the same upstream, prefer configuring that provider as `auth: "hybrid"`: + +- it uses `api_key` first +- when the API key hits quota/rate-limit style failures, it automatically falls back to the provider's OAuth account pool +- OAuth accounts still keep multi-account rotation, background pre-refresh, `auth.json` import, and WebUI management +- `oauth.hybrid_priority` supports `api_first` or `oauth_first` +- `oauth.cooldown_sec` controls how long a rate-limited OAuth account stays out of rotation; default is `900` +- the provider runtime panel shows current candidate ordering, the most recent successful credential, and recent hit/error history +- to persist runtime history across restarts, configure `runtime_persist`, `runtime_history_file`, and `runtime_history_max` on the provider + ### 4. Start Interactive mode: diff --git a/cmd/cmd_config.go b/cmd/cmd_config.go index 8e99fb1..4ab3b17 100644 --- a/cmd/cmd_config.go +++ b/cmd/cmd_config.go @@ -2,15 +2,18 @@ package main import ( "bufio" + "context" "encoding/json" "fmt" "os" "sort" "strconv" "strings" + "time" "github.com/YspCoder/clawgo/pkg/config" "github.com/YspCoder/clawgo/pkg/configops" + "github.com/YspCoder/clawgo/pkg/providers" ) func configCmd() { @@ -172,6 +175,16 @@ func configCheckCmd() { } func providerCmd() { + if len(os.Args) >= 3 { + switch strings.TrimSpace(os.Args[2]) { + case "login": + providerLoginCmd() + return + case "configure": + // Continue into the interactive editor below. + } + } + cfg, err := loadConfig() if err != nil { fmt.Printf("Error loading config: %v\n", err) @@ -222,10 +235,19 @@ func providerCmd() { if models := parseCSV(modelsRaw); len(models) > 0 { pc.Models = models } - pc.Auth = promptLine(reader, "auth (bearer/oauth/none)", pc.Auth) + pc.Auth = promptLine(reader, "auth (bearer/oauth/hybrid/none)", pc.Auth) timeoutRaw := promptLine(reader, "timeout_sec", fmt.Sprintf("%d", pc.TimeoutSec)) pc.TimeoutSec = parseIntOrDefault(timeoutRaw, pc.TimeoutSec) pc.SupportsResponsesCompact = promptBool(reader, "supports_responses_compact", pc.SupportsResponsesCompact) + if strings.EqualFold(strings.TrimSpace(pc.Auth), "oauth") || strings.EqualFold(strings.TrimSpace(pc.Auth), "hybrid") { + pc.OAuth.Provider = promptLine(reader, "oauth.provider", firstNonEmptyString(pc.OAuth.Provider, "codex")) + pc.OAuth.CredentialFile = promptLine(reader, "oauth.credential_file", pc.OAuth.CredentialFile) + pc.OAuth.CallbackPort = parseIntOrDefault(promptLine(reader, "oauth.callback_port", fmt.Sprintf("%d", defaultInt(pc.OAuth.CallbackPort, 1455))), defaultInt(pc.OAuth.CallbackPort, 1455)) + if strings.EqualFold(strings.TrimSpace(pc.Auth), "hybrid") { + pc.OAuth.HybridPriority = promptLine(reader, "oauth.hybrid_priority (api_first/oauth_first)", firstNonEmptyString(pc.OAuth.HybridPriority, "api_first")) + } + pc.OAuth.CooldownSec = parseIntOrDefault(promptLine(reader, "oauth.cooldown_sec", fmt.Sprintf("%d", defaultInt(pc.OAuth.CooldownSec, 900))), defaultInt(pc.OAuth.CooldownSec, 900)) + } setProviderConfigByName(cfg, providerName, pc) @@ -278,6 +300,119 @@ func providerCmd() { fmt.Println("鉁?Gateway hot reload signal sent") } +func providerLoginCmd() { + cfg, err := loadConfig() + if err != nil { + fmt.Printf("Error loading config: %v\n", err) + os.Exit(1) + } + + providerName := strings.TrimSpace(cfg.Agents.Defaults.Proxy) + if providerName == "" { + providerName = "proxy" + } + manual := false + noBrowser := false + accountLabel := "" + for i := 3; i < len(os.Args); i++ { + arg := strings.TrimSpace(os.Args[i]) + switch arg { + case "--manual": + manual = true + case "--no-browser": + noBrowser = true + case "--label": + if i+1 < len(os.Args) { + i++ + accountLabel = strings.TrimSpace(os.Args[i]) + } + case "": + default: + if strings.HasPrefix(arg, "--label=") { + accountLabel = strings.TrimSpace(strings.TrimPrefix(arg, "--label=")) + continue + } + if !strings.HasPrefix(arg, "-") { + providerName = arg + } + } + } + + pc := providerConfigByName(cfg, providerName) + if !strings.EqualFold(strings.TrimSpace(pc.Auth), "oauth") && !strings.EqualFold(strings.TrimSpace(pc.Auth), "hybrid") { + fmt.Printf("Provider %s is not configured with auth=oauth/hybrid\n", providerName) + os.Exit(1) + } + if manual { + noBrowser = true + } + if manual && strings.TrimSpace(pc.OAuth.RedirectURL) == "" && pc.OAuth.CallbackPort <= 0 { + pc.OAuth.CallbackPort = 1455 + } + + timeout := pc.TimeoutSec + if timeout <= 0 { + timeout = 90 + } + oauth, err := providers.NewOAuthLoginManager(pc, time.Duration(timeout)*time.Second) + if err != nil { + fmt.Printf("Error preparing oauth login: %v\n", err) + os.Exit(1) + } + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute) + defer cancel() + session, models, err := oauth.Login(ctx, pc.APIBase, providers.OAuthLoginOptions{ + Manual: manual, + NoBrowser: noBrowser, + Reader: os.Stdin, + AccountLabel: accountLabel, + }) + if err != nil { + fmt.Printf("OAuth login failed: %v\n", err) + os.Exit(1) + } + + if len(models) > 0 { + pc.Models = models + } + if session.CredentialFile != "" { + pc.OAuth.CredentialFile = session.CredentialFile + pc.OAuth.CredentialFiles = appendUniqueCSV(pc.OAuth.CredentialFiles, session.CredentialFile) + } else if pc.OAuth.CredentialFile == "" { + pc.OAuth.CredentialFile = oauth.CredentialFile() + pc.OAuth.CredentialFiles = appendUniqueCSV(pc.OAuth.CredentialFiles, pc.OAuth.CredentialFile) + } + setProviderConfigByName(cfg, providerName, pc) + + if err := config.SaveConfig(getConfigPath(), cfg); err != nil { + fmt.Printf("Error saving config: %v\n", err) + os.Exit(1) + } + + fmt.Printf("OAuth login succeeded for provider %s\n", providerName) + if manual { + fmt.Println("Mode: manual callback URL paste") + } else if noBrowser { + fmt.Println("Mode: local callback listener without auto-opening browser") + } + if session.Email != "" { + fmt.Printf("Account: %s\n", session.Email) + } + fmt.Printf("Credential file: %s\n", firstNonEmptyString(session.CredentialFile, oauth.CredentialFile())) + if len(pc.OAuth.CredentialFiles) > 1 { + fmt.Printf("OAuth accounts: %d\n", len(pc.OAuth.CredentialFiles)) + } + if len(models) > 0 { + fmt.Printf("Discovered models: %s\n", strings.Join(models, ", ")) + } + if running, reloadErr := triggerGatewayReload(); reloadErr == nil { + fmt.Println("Gateway hot reload signal sent") + } else if running { + fmt.Printf("Hot reload not applied: %v\n", reloadErr) + } +} + func providerNames(cfg *config.Config) []string { names := []string{"proxy"} for k := range cfg.Providers.Proxies { @@ -370,6 +505,35 @@ func parseCSV(raw string) []string { return out } +func defaultInt(value, fallback int) int { + if value > 0 { + return value + } + return fallback +} + +func firstNonEmptyString(values ...string) string { + for _, value := range values { + if trimmed := strings.TrimSpace(value); trimmed != "" { + return trimmed + } + } + return "" +} + +func appendUniqueCSV(values []string, value string) []string { + trimmed := strings.TrimSpace(value) + if trimmed == "" { + return values + } + for _, item := range values { + if strings.TrimSpace(item) == trimmed { + return values + } + } + return append(values, trimmed) +} + func parseIntOrDefault(raw string, def int) int { raw = strings.TrimSpace(raw) if raw == "" { diff --git a/cmd/cmd_gateway.go b/cmd/cmd_gateway.go index fdbf2b1..89d6bac 100644 --- a/cmd/cmd_gateway.go +++ b/cmd/cmd_gateway.go @@ -75,15 +75,7 @@ func gatewayCmd() { cfg.WorkspacePath(), cfg.Sentinel.IntervalSec, cfg.Sentinel.AutoHeal, - func(message string) { - if cfg.Sentinel.NotifyChannel != "" && cfg.Sentinel.NotifyChatID != "" { - msgBus.PublishOutbound(bus.OutboundMessage{ - Channel: cfg.Sentinel.NotifyChannel, - ChatID: cfg.Sentinel.NotifyChatID, - Content: "[Sentinel] " + message, - }) - } - }, + buildSentinelAlertHandler(cfg, msgBus), ) ctx, cancel := context.WithCancel(context.Background()) @@ -421,15 +413,7 @@ func gatewayCmd() { newCfg.WorkspacePath(), newCfg.Sentinel.IntervalSec, newCfg.Sentinel.AutoHeal, - func(message string) { - if newCfg.Sentinel.NotifyChannel != "" && newCfg.Sentinel.NotifyChatID != "" { - msgBus.PublishOutbound(bus.OutboundMessage{ - Channel: newCfg.Sentinel.NotifyChannel, - ChatID: newCfg.Sentinel.NotifyChatID, - Content: "[Sentinel] " + message, - }) - } - }, + buildSentinelAlertHandler(newCfg, msgBus), ) if newCfg.Sentinel.Enabled { sentinelService.SetManager(channelManager) @@ -470,15 +454,7 @@ func gatewayCmd() { newCfg.WorkspacePath(), newCfg.Sentinel.IntervalSec, newCfg.Sentinel.AutoHeal, - func(message string) { - if newCfg.Sentinel.NotifyChannel != "" && newCfg.Sentinel.NotifyChatID != "" { - msgBus.PublishOutbound(bus.OutboundMessage{ - Channel: newCfg.Sentinel.NotifyChannel, - ChatID: newCfg.Sentinel.NotifyChatID, - Content: "[Sentinel] " + message, - }) - } - }, + buildSentinelAlertHandler(newCfg, msgBus), ) if newCfg.Sentinel.Enabled { sentinelService.Start() diff --git a/cmd/sentinel_notify.go b/cmd/sentinel_notify.go new file mode 100644 index 0000000..3fc781b --- /dev/null +++ b/cmd/sentinel_notify.go @@ -0,0 +1,84 @@ +package main + +import ( + "bytes" + "encoding/json" + "net/http" + "strings" + "time" + + "github.com/YspCoder/clawgo/pkg/bus" + "github.com/YspCoder/clawgo/pkg/config" + "github.com/YspCoder/clawgo/pkg/logger" +) + +type sentinelWebhookPayload struct { + Source string `json:"source"` + Level string `json:"level"` + Message string `json:"message"` + Timestamp string `json:"timestamp"` +} + +func buildSentinelAlertHandler(cfg *config.Config, msgBus *bus.MessageBus) func(string) { + return func(message string) { + if cfg == nil { + return + } + sendSentinelChannelAlert(cfg, msgBus, message) + sendSentinelWebhookAlert(cfg, message) + } +} + +func sendSentinelChannelAlert(cfg *config.Config, msgBus *bus.MessageBus, message string) { + if cfg == nil || msgBus == nil { + return + } + if strings.TrimSpace(cfg.Sentinel.NotifyChannel) == "" || strings.TrimSpace(cfg.Sentinel.NotifyChatID) == "" { + return + } + msgBus.PublishOutbound(bus.OutboundMessage{ + Channel: cfg.Sentinel.NotifyChannel, + ChatID: cfg.Sentinel.NotifyChatID, + Content: "[Sentinel] " + message, + }) +} + +func sendSentinelWebhookAlert(cfg *config.Config, message string) { + if cfg == nil { + return + } + webhookURL := strings.TrimSpace(cfg.Sentinel.WebhookURL) + if webhookURL == "" { + return + } + payload := sentinelWebhookPayload{ + Source: "sentinel", + Level: "warning", + Message: message, + Timestamp: time.Now().UTC().Format(time.RFC3339), + } + body, err := json.Marshal(payload) + if err != nil { + logger.ErrorCF("sentinel", logger.C0137, map[string]interface{}{"error": err.Error(), "target": "webhook marshal"}) + return + } + req, err := http.NewRequest(http.MethodPost, webhookURL, bytes.NewReader(body)) + if err != nil { + logger.ErrorCF("sentinel", logger.C0137, map[string]interface{}{"error": err.Error(), "target": "webhook request"}) + return + } + req.Header.Set("Content-Type", "application/json") + + resp, err := (&http.Client{Timeout: 8 * time.Second}).Do(req) + if err != nil { + logger.ErrorCF("sentinel", logger.C0137, map[string]interface{}{"error": err.Error(), "target": "webhook send"}) + return + } + defer resp.Body.Close() + if resp.StatusCode >= 300 { + logger.ErrorCF("sentinel", logger.C0137, map[string]interface{}{ + "error": "unexpected webhook status", + "status": resp.StatusCode, + }) + } +} diff --git a/cmd/sentinel_notify_test.go b/cmd/sentinel_notify_test.go new file mode 100644 index 0000000..adbd534 --- /dev/null +++ b/cmd/sentinel_notify_test.go @@ -0,0 +1,85 @@ +package main + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/YspCoder/clawgo/pkg/bus" + "github.com/YspCoder/clawgo/pkg/config" +) + +func TestBuildSentinelAlertHandlerPublishesChannelAlertAndWebhook(t *testing.T) { + t.Parallel() + + msgBus := bus.NewMessageBus() + + webhookCh := make(chan sentinelWebhookPayload, 1) + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + defer r.Body.Close() + if r.Method != http.MethodPost { + t.Fatalf("expected POST, got %s", r.Method) + } + if got := r.Header.Get("Content-Type"); got != "application/json" { + t.Fatalf("unexpected content type: %s", got) + } + var payload sentinelWebhookPayload + if err := json.NewDecoder(r.Body).Decode(&payload); err != nil { + t.Fatalf("decode payload: %v", err) + } + webhookCh <- payload + w.WriteHeader(http.StatusNoContent) + })) + defer srv.Close() + + cfg := config.DefaultConfig() + cfg.Sentinel.NotifyChannel = "telegram" + cfg.Sentinel.NotifyChatID = "chat-1" + cfg.Sentinel.WebhookURL = srv.URL + + handler := buildSentinelAlertHandler(cfg, msgBus) + handler("disk usage high") + + outCtx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + msg, ok := msgBus.SubscribeOutbound(outCtx) + if !ok { + t.Fatal("expected outbound channel alert") + } + if msg.Channel != "telegram" || msg.ChatID != "chat-1" { + t.Fatalf("unexpected outbound route: %+v", msg) + } + if msg.Content != "[Sentinel] disk usage high" { + t.Fatalf("unexpected outbound content: %s", msg.Content) + } + + select { + case payload := <-webhookCh: + if payload.Source != "sentinel" { + t.Fatalf("unexpected source: %s", payload.Source) + } + if payload.Level != "warning" { + t.Fatalf("unexpected level: %s", payload.Level) + } + if payload.Message != "disk usage high" { + t.Fatalf("unexpected message: %s", payload.Message) + } + if _, err := time.Parse(time.RFC3339, payload.Timestamp); err != nil { + t.Fatalf("unexpected timestamp: %v", err) + } + case <-time.After(2 * time.Second): + t.Fatal("expected webhook alert") + } +} + +func TestBuildSentinelAlertHandlerSkipsEmptyTargets(t *testing.T) { + t.Parallel() + + msgBus := bus.NewMessageBus() + cfg := config.DefaultConfig() + handler := buildSentinelAlertHandler(cfg, msgBus) + handler("noop") +} diff --git a/config.example.json b/config.example.json index f653e13..5efda48 100644 --- a/config.example.json +++ b/config.example.json @@ -218,6 +218,110 @@ "timeout_sec": 90 }, "proxies": { + "codex-oauth": { + "api_base": "https://api.openai.com/v1", + "models": [], + "responses": { + "web_search_enabled": false, + "web_search_context_size": "", + "file_search_vector_store_ids": [], + "file_search_max_num_results": 0, + "include": [], + "stream_include_usage": false + }, + "supports_responses_compact": true, + "auth": "oauth", + "oauth": { + "provider": "codex", + "credential_file": "~/.clawgo/auth/codex.json", + "credential_files": ["~/.clawgo/auth/codex.json"], + "callback_port": 1455, + "refresh_scan_sec": 600, + "refresh_lead_sec": 1800 + }, + "runtime_persist": true, + "runtime_history_file": "~/.clawgo/runtime/providers/codex-oauth.json", + "runtime_history_max": 24, + "timeout_sec": 90 + }, + "gemini-oauth": { + "api_base": "https://your-openai-compatible-gateway.example.com/v1", + "models": [], + "responses": { + "web_search_enabled": false, + "web_search_context_size": "", + "file_search_vector_store_ids": [], + "file_search_max_num_results": 0, + "include": [], + "stream_include_usage": false + }, + "supports_responses_compact": true, + "auth": "oauth", + "oauth": { + "provider": "gemini", + "client_secret": "", + "credential_files": ["~/.clawgo/auth/gemini.json"], + "callback_port": 8085, + "refresh_scan_sec": 600, + "refresh_lead_sec": 1800 + }, + "runtime_persist": true, + "runtime_history_file": "~/.clawgo/runtime/providers/gemini-oauth.json", + "runtime_history_max": 24, + "timeout_sec": 90 + }, + "openai-hybrid": { + "api_key": "sk-your-primary-api-key", + "api_base": "https://api.openai.com/v1", + "models": [], + "responses": { + "web_search_enabled": false, + "web_search_context_size": "", + "file_search_vector_store_ids": [], + "file_search_max_num_results": 0, + "include": [], + "stream_include_usage": false + }, + "supports_responses_compact": true, + "auth": "hybrid", + "oauth": { + "provider": "codex", + "credential_files": ["~/.clawgo/auth/codex.json"], + "callback_port": 1455, + "hybrid_priority": "api_first", + "cooldown_sec": 900, + "refresh_scan_sec": 600, + "refresh_lead_sec": 1800 + }, + "runtime_persist": true, + "runtime_history_file": "~/.clawgo/runtime/providers/openai-hybrid.json", + "runtime_history_max": 24, + "timeout_sec": 90 + }, + "qwen-oauth": { + "api_base": "https://your-openai-compatible-gateway.example.com/v1", + "models": [], + "responses": { + "web_search_enabled": false, + "web_search_context_size": "", + "file_search_vector_store_ids": [], + "file_search_max_num_results": 0, + "include": [], + "stream_include_usage": false + }, + "supports_responses_compact": true, + "auth": "oauth", + "oauth": { + "provider": "qwen", + "credential_files": ["~/.clawgo/auth/qwen.json"], + "refresh_scan_sec": 600, + "refresh_lead_sec": 1800 + }, + "runtime_persist": true, + "runtime_history_file": "~/.clawgo/runtime/providers/qwen-oauth.json", + "runtime_history_max": 24, + "timeout_sec": 90 + }, "backup": { "api_key": "YOUR_BACKUP_PROXY_KEY", "api_base": "http://localhost:8081/v1", @@ -305,5 +409,13 @@ "filename": "clawgo.log", "max_size_mb": 20, "retention_days": 3 + }, + "sentinel": { + "enabled": true, + "interval_sec": 60, + "auto_heal": true, + "notify_channel": "", + "notify_chat_id": "", + "webhook_url": "" } } diff --git a/go.mod b/go.mod index b61dc0e..3072398 100644 --- a/go.mod +++ b/go.mod @@ -78,6 +78,7 @@ require ( github.com/pion/stun/v3 v3.1.1 // indirect github.com/pion/transport/v4 v4.0.1 // indirect github.com/pion/turn/v4 v4.1.4 // indirect + github.com/refraction-networking/utls v1.8.2 // indirect github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect github.com/rivo/uniseg v0.4.7 // indirect github.com/rs/zerolog v1.34.0 // indirect diff --git a/go.sum b/go.sum index 3fd50d8..d863cc6 100644 --- a/go.sum +++ b/go.sum @@ -201,6 +201,8 @@ github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsK github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/refraction-networking/utls v1.8.2 h1:j4Q1gJj0xngdeH+Ox/qND11aEfhpgoEvV+S9iJ2IdQo= +github.com/refraction-networking/utls v1.8.2/go.mod h1:jkSOEkLqn+S/jtpEHPOsVv/4V4EVnelwbMQl4vCWXAM= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ= diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go index 5d5e996..9b9d9f2 100644 --- a/pkg/agent/loop.go +++ b/pkg/agent/loop.go @@ -18,6 +18,7 @@ import ( "regexp" "runtime" "sort" + "strconv" "strings" "sync" "time" @@ -772,6 +773,23 @@ func (al *AgentLoop) appendTaskAuditEvent(taskID string, msg bus.InboundMessage, "provider": al.getSessionProvider(msg.SessionKey), "model": al.model, } + if msg.Metadata != nil { + if v := strings.TrimSpace(msg.Metadata["context_extra_chars"]); v != "" { + if n, err := strconv.Atoi(v); err == nil { + row["context_extra_chars"] = n + } + } + if v := strings.TrimSpace(msg.Metadata["context_ekg_chars"]); v != "" { + if n, err := strconv.Atoi(v); err == nil { + row["context_ekg_chars"] = n + } + } + if v := strings.TrimSpace(msg.Metadata["context_memory_chars"]); v != "" { + if n, err := strconv.Atoi(v); err == nil { + row["context_memory_chars"] = n + } + } + } if al.ekg != nil { al.ekg.Record(ekg.Event{ TaskID: taskID, diff --git a/pkg/agent/loop_audit_test.go b/pkg/agent/loop_audit_test.go new file mode 100644 index 0000000..9c9bb0a --- /dev/null +++ b/pkg/agent/loop_audit_test.go @@ -0,0 +1,48 @@ +package agent + +import ( + "encoding/json" + "os" + "path/filepath" + "testing" + "time" + + "github.com/YspCoder/clawgo/pkg/bus" +) + +func TestAppendTaskAuditEventPersistsContextCharStats(t *testing.T) { + t.Parallel() + + workspace := t.TempDir() + al := &AgentLoop{workspace: workspace} + msg := bus.InboundMessage{ + Channel: "chat", + SessionKey: "s1", + Content: "Task Context:\nEKG: repeat_errsig=perm\nTask:\ndeploy", + Metadata: map[string]string{ + "context_extra_chars": "42", + "context_ekg_chars": "18", + "context_memory_chars": "0", + }, + } + + al.appendTaskAuditEvent("task-1", msg, "success", time.Now().Add(-time.Second), 1000, "completed", false) + + b, err := os.ReadFile(filepath.Join(workspace, "memory", "task-audit.jsonl")) + if err != nil { + t.Fatalf("read task audit: %v", err) + } + var row map[string]interface{} + if err := json.Unmarshal(b[:len(b)-1], &row); err != nil { + t.Fatalf("decode task audit row: %v", err) + } + if got := int(row["context_extra_chars"].(float64)); got != 42 { + t.Fatalf("expected context_extra_chars=42, got %d", got) + } + if got := int(row["context_ekg_chars"].(float64)); got != 18 { + t.Fatalf("expected context_ekg_chars=18, got %d", got) + } + if got := int(row["context_memory_chars"].(float64)); got != 0 { + t.Fatalf("expected context_memory_chars=0, got %d", got) + } +} diff --git a/pkg/agent/session_planner.go b/pkg/agent/session_planner.go index 48c2a55..1bf8321 100644 --- a/pkg/agent/session_planner.go +++ b/pkg/agent/session_planner.go @@ -22,6 +22,7 @@ import ( type plannedTask struct { Index int + Total int Content string ResourceKeys []string } @@ -82,6 +83,7 @@ func (al *AgentLoop) planSessionTasks(ctx context.Context, msg bus.InboundMessag } out = append(out, plannedTask{ Index: i + 1, + Total: 0, Content: content, ResourceKeys: scheduling.DeriveResourceKeys(content), }) @@ -93,6 +95,9 @@ func (al *AgentLoop) planSessionTasks(ctx context.Context, msg bus.InboundMessag out[0].Content = base out[0].ResourceKeys = scheduling.DeriveResourceKeys(base) } + for i := range out { + out[i].Total = len(out) + } return out } @@ -254,6 +259,7 @@ func splitPlannedSegments(content string) []string { func (al *AgentLoop) runPlannedTasks(ctx context.Context, msg bus.InboundMessage, tasks []plannedTask) (string, error) { results := make([]plannedTaskResult, len(tasks)) + enrichedContent := al.enrichPlannedTaskContents(ctx, tasks) var wg sync.WaitGroup var progressMu sync.Mutex completed := 0 @@ -265,7 +271,8 @@ func (al *AgentLoop) runPlannedTasks(ctx context.Context, msg bus.InboundMessage go func(index int, t plannedTask) { defer wg.Done() subMsg := msg - subMsg.Content = al.enrichTaskContentWithMemoryAndEKG(ctx, t) + enriched := enrichedContent[index] + subMsg.Content = enriched.content subMsg.Metadata = cloneMetadata(msg.Metadata) if subMsg.Metadata == nil { subMsg.Metadata = map[string]string{} @@ -273,6 +280,15 @@ func (al *AgentLoop) runPlannedTasks(ctx context.Context, msg bus.InboundMessage subMsg.Metadata["resource_keys"] = strings.Join(t.ResourceKeys, ",") subMsg.Metadata["planned_task_index"] = fmt.Sprintf("%d", t.Index) subMsg.Metadata["planned_task_total"] = fmt.Sprintf("%d", len(tasks)) + if enriched.extraChars > 0 { + subMsg.Metadata["context_extra_chars"] = fmt.Sprintf("%d", enriched.extraChars) + } + if enriched.ekgChars > 0 { + subMsg.Metadata["context_ekg_chars"] = fmt.Sprintf("%d", enriched.ekgChars) + } + if enriched.memoryChars > 0 { + subMsg.Metadata["context_memory_chars"] = fmt.Sprintf("%d", enriched.memoryChars) + } out, err := al.processMessage(ctx, subMsg) res := plannedTaskResult{Index: index, Task: t, Output: strings.TrimSpace(out), Err: err} if err != nil { @@ -434,35 +450,70 @@ func summarizePlannedTaskProgressBody(body string, maxLines, maxChars int) strin return joined } -func (al *AgentLoop) enrichTaskContentWithMemoryAndEKG(ctx context.Context, task plannedTask) string { - base := strings.TrimSpace(task.Content) - if base == "" { - return base - } - hints := make([]string, 0, 2) - if mem := al.memoryHintForTask(ctx, task); mem != "" { - hints = append(hints, "Memory:\n"+mem) - } - if risk := al.ekgHintForTask(task); risk != "" { - hints = append(hints, "EKG:\n"+risk) - } - if len(hints) == 0 { - return base - } - return strings.TrimSpace( - "Task Context (use it as constraints, avoid repeating known failures):\n" + - strings.Join(hints, "\n\n") + - "\n\nTask:\n" + base, - ) +type taskPromptHints struct { + ekg string + memory string } -func (al *AgentLoop) memoryHintForTask(ctx context.Context, task plannedTask) string { +type plannedTaskPrompt struct { + content string + extraChars int + ekgChars int + memoryChars int +} + +func (al *AgentLoop) enrichPlannedTaskContents(ctx context.Context, tasks []plannedTask) []plannedTaskPrompt { + out := make([]plannedTaskPrompt, len(tasks)) + seen := make(map[string]struct{}, len(tasks)*2) + remainingBudget := plannedTaskContextBudget(len(tasks)) + for i, task := range tasks { + hints := al.collectTaskPromptHints(ctx, task) + if len(tasks) > 1 { + dedupeTaskPromptHints(&hints, seen) + applyPromptBudget(&hints, &remainingBudget) + } + prompt := buildPlannedTaskPrompt(task.Content, hints) + if len(tasks) > 1 && prompt.extraChars > 0 { + remainingBudget -= prompt.extraChars + if remainingBudget < 0 { + remainingBudget = 0 + } + } + out[i] = prompt + } + return out +} + +func (al *AgentLoop) enrichTaskContentWithMemoryAndEKG(ctx context.Context, task plannedTask) string { + return buildPlannedTaskPrompt(task.Content, al.collectTaskPromptHints(ctx, task)).content +} + +func (al *AgentLoop) collectTaskPromptHints(ctx context.Context, task plannedTask) taskPromptHints { + hints := taskPromptHints{} + if risk := al.ekgHintForTask(task); risk != "" { + hints.ekg = risk + hints.memory = al.memoryHintForTask(ctx, task, true) + return hints + } + hints.memory = al.memoryHintForTask(ctx, task, false) + return hints +} + +func (al *AgentLoop) memoryHintForTask(ctx context.Context, task plannedTask, hasEKG bool) string { if al == nil || al.tools == nil { return "" } + maxResults := 1 + maxChars := 360 + if task.Total > 1 { + maxChars = 220 + } + if hasEKG { + maxChars = 160 + } args := map[string]interface{}{ "query": task.Content, - "maxResults": 2, + "maxResults": maxResults, } if ns := memoryNamespaceFromContext(ctx); ns != "main" { args["namespace"] = ns @@ -475,7 +526,7 @@ func (al *AgentLoop) memoryHintForTask(ctx context.Context, task plannedTask) st if txt == "" || strings.HasPrefix(strings.ToLower(txt), "no memory found") { return "" } - return truncate(txt, 1200) + return compactMemoryHint(txt, maxChars) } func (al *AgentLoop) ekgHintForTask(task plannedTask) string { @@ -499,20 +550,27 @@ func (al *AgentLoop) ekgHintForTask(task plannedTask) string { if !advice.ShouldEscalate { return "" } - reasons := strings.Join(advice.Reason, ", ") - if strings.TrimSpace(reasons) == "" { - reasons = "repeated error signature" + parts := []string{ + fmt.Sprintf("repeat_errsig=%s", truncate(errSig, 72)), + fmt.Sprintf("backoff=%ds", advice.RetryBackoffSec), } - return fmt.Sprintf("Related repeated error signature detected (%s). Suggested retry backoff: %ds. Last error: %s", - errSig, advice.RetryBackoffSec, truncate(strings.TrimSpace(evt.Log), 240)) + if evt.Preview != "" { + parts = append(parts, "related_task="+truncate(strings.TrimSpace(evt.Preview), 96)) + } + if len(advice.Reason) > 0 { + parts = append(parts, "reason="+truncate(strings.Join(advice.Reason, "+"), 64)) + } + return strings.Join(parts, "; ") } type taskAuditErrorEvent struct { - TaskID string - Source string - Channel string - Log string - Preview string + TaskID string + Source string + Channel string + Log string + Preview string + MatchScore int + MatchRatio float64 } func (al *AgentLoop) findRecentRelatedErrorEvent(taskContent string) (taskAuditErrorEvent, bool) { @@ -529,6 +587,7 @@ func (al *AgentLoop) findRecentRelatedErrorEvent(taskContent string) (taskAuditE } var best taskAuditErrorEvent bestScore := 0 + bestRatio := 0.0 s := bufio.NewScanner(f) for s.Scan() { @@ -548,17 +607,25 @@ func (al *AgentLoop) findRecentRelatedErrorEvent(taskContent string) (taskAuditE continue } preview := strings.TrimSpace(fmt.Sprintf("%v", row["input_preview"])) - score := overlapScore(kw, tokenizeTaskText(preview)) - if score < 1 || score < bestScore { + previewKW := tokenizeTaskText(preview) + score := overlapScore(kw, previewKW) + ratio := overlapRatio(kw, previewKW, score) + if !isStrongTaskMatch(score, ratio) { + continue + } + if score < bestScore || (score == bestScore && ratio < bestRatio) { continue } bestScore = score + bestRatio = ratio best = taskAuditErrorEvent{ - TaskID: strings.TrimSpace(fmt.Sprintf("%v", row["task_id"])), - Source: strings.TrimSpace(fmt.Sprintf("%v", row["source"])), - Channel: strings.TrimSpace(fmt.Sprintf("%v", row["channel"])), - Log: logText, - Preview: preview, + TaskID: strings.TrimSpace(fmt.Sprintf("%v", row["task_id"])), + Source: strings.TrimSpace(fmt.Sprintf("%v", row["source"])), + Channel: strings.TrimSpace(fmt.Sprintf("%v", row["channel"])), + Log: logText, + Preview: preview, + MatchScore: score, + MatchRatio: ratio, } } if bestScore == 0 || strings.TrimSpace(best.TaskID) == "" { @@ -597,6 +664,162 @@ func overlapScore(a, b []string) int { return score } +func overlapRatio(a, b []string, score int) float64 { + if score <= 0 || len(a) == 0 || len(b) == 0 { + return 0 + } + shorter := len(a) + if len(b) < shorter { + shorter = len(b) + } + if shorter <= 0 { + return 0 + } + return float64(score) / float64(shorter) +} + +func isStrongTaskMatch(score int, ratio float64) bool { + if score >= 4 { + return true + } + if score < 2 { + return false + } + return ratio >= 0.35 +} + +func compactMemoryHint(raw string, maxChars int) string { + raw = strings.ReplaceAll(raw, "\r\n", "\n") + lines := strings.Split(raw, "\n") + parts := make([]string, 0, len(lines)) + for _, line := range lines { + line = strings.TrimSpace(line) + if line == "" { + continue + } + lower := strings.ToLower(line) + if strings.HasPrefix(lower, "found ") { + continue + } + if strings.HasPrefix(lower, "source: ") { + line = strings.TrimSpace(strings.TrimPrefix(line, "Source: ")) + if line != "" { + parts = append(parts, "src="+line) + } + continue + } + parts = append(parts, line) + if len(parts) >= 2 { + break + } + } + if len(parts) == 0 { + return "" + } + return truncate(strings.Join(parts, " | "), maxChars) +} + +func renderTaskPromptWithHints(taskContent string, hints taskPromptHints) string { + return buildPlannedTaskPrompt(taskContent, hints).content +} + +func buildPlannedTaskPrompt(taskContent string, hints taskPromptHints) plannedTaskPrompt { + base := strings.TrimSpace(taskContent) + if base == "" { + return plannedTaskPrompt{} + } + if hints.ekg == "" && hints.memory == "" { + return plannedTaskPrompt{content: base} + } + lines := make([]string, 0, 4) + lines = append(lines, "Task Context:") + if hints.ekg != "" { + lines = append(lines, "EKG: "+hints.ekg) + } + if hints.memory != "" { + lines = append(lines, "Memory: "+hints.memory) + } + lines = append(lines, "Task:", base) + content := strings.Join(lines, "\n") + return plannedTaskPrompt{ + content: content, + extraChars: maxInt(len(content)-len(base), 0), + ekgChars: len(hints.ekg), + memoryChars: len(hints.memory), + } +} + +func plannedTaskContextBudget(taskCount int) int { + if taskCount <= 1 { + return 1 << 30 + } + return 320 +} + +func applyPromptBudget(hints *taskPromptHints, remaining *int) { + if hints == nil || remaining == nil { + return + } + if *remaining <= 0 { + hints.ekg = "" + hints.memory = "" + return + } + needed := estimateHintChars(*hints) + if needed <= *remaining { + return + } + hints.memory = "" + needed = estimateHintChars(*hints) + if needed <= *remaining { + return + } + hints.ekg = "" +} + +func estimateHintChars(hints taskPromptHints) int { + total := 0 + if hints.ekg != "" { + total += len("Task Context:\nEKG: \nTask:\n") + len(hints.ekg) + } + if hints.memory != "" { + total += len("Memory: \n") + len(hints.memory) + if hints.ekg == "" { + total += len("Task Context:\nTask:\n") + } + } + return total +} + +func dedupeTaskPromptHints(hints *taskPromptHints, seen map[string]struct{}) { + if hints == nil || seen == nil { + return + } + if hints.ekg != "" { + key := "ekg:" + hints.ekg + if _, ok := seen[key]; ok { + hints.ekg = "" + } else { + seen[key] = struct{}{} + } + } + if hints.memory != "" { + key := "memory:" + hints.memory + if _, ok := seen[key]; ok { + hints.memory = "" + } else { + seen[key] = struct{}{} + } + } +} + +func maxInt(a, b int) int { + if a > b { + return a + } + return b +} + func cloneMetadata(m map[string]string) map[string]string { if len(m) == 0 { return nil diff --git a/pkg/agent/session_planner_test.go b/pkg/agent/session_planner_test.go index 900a7b0..d9ca3db 100644 --- a/pkg/agent/session_planner_test.go +++ b/pkg/agent/session_planner_test.go @@ -1,8 +1,13 @@ package agent import ( + "fmt" + "os" + "path/filepath" "strings" "testing" + + "github.com/YspCoder/clawgo/pkg/ekg" ) func TestSummarizePlannedTaskProgressBodyPreservesUsefulLines(t *testing.T) { @@ -21,3 +26,164 @@ func TestSummarizePlannedTaskProgressBodyPreservesUsefulLines(t *testing.T) { t.Fatalf("expected multi-line formatting, got:\n%s", out) } } + +func TestEKGHintForTaskRequiresStrongMatchAndStaysCompact(t *testing.T) { + t.Parallel() + + workspace := t.TempDir() + memoryDir := filepath.Join(workspace, "memory") + if err := os.MkdirAll(memoryDir, 0o755); err != nil { + t.Fatalf("mkdir memory: %v", err) + } + logText := "open /srv/app/config.yaml: permission denied after deploy 42" + taskAudit := []string{ + fmt.Sprintf(`{"task_id":"task-1","status":"error","source":"planner","channel":"chat","input_preview":"check nginx logs quickly","log":"%s"}`, logText), + fmt.Sprintf(`{"task_id":"task-2","status":"error","source":"planner","channel":"chat","input_preview":"deploy config service restart on cluster","log":"%s"}`, logText), + } + if err := os.WriteFile(filepath.Join(memoryDir, "task-audit.jsonl"), []byte(strings.Join(taskAudit, "\n")+"\n"), 0o644); err != nil { + t.Fatalf("write task audit: %v", err) + } + errSig := ekg.NormalizeErrorSignature(logText) + ekgEvents := []string{ + fmt.Sprintf(`{"task_id":"task-2","status":"error","errsig":"%s","log":"%s"}`, errSig, logText), + fmt.Sprintf(`{"task_id":"task-2","status":"error","errsig":"%s","log":"%s"}`, errSig, logText), + fmt.Sprintf(`{"task_id":"task-2","status":"error","errsig":"%s","log":"%s"}`, errSig, logText), + } + if err := os.WriteFile(filepath.Join(memoryDir, "ekg-events.jsonl"), []byte(strings.Join(ekgEvents, "\n")+"\n"), 0o644); err != nil { + t.Fatalf("write ekg events: %v", err) + } + + al := &AgentLoop{workspace: workspace, ekg: ekg.New(workspace)} + hint := al.ekgHintForTask(plannedTask{Content: "deploy config service restart after rollout"}) + if hint == "" { + t.Fatalf("expected compact ekg hint") + } + if !strings.Contains(hint, "repeat_errsig=") || !strings.Contains(hint, "backoff=300s") { + t.Fatalf("expected compact fields, got: %s", hint) + } + if strings.Contains(strings.ToLower(hint), "last error") || strings.Contains(hint, logText) { + t.Fatalf("expected raw error log to be omitted, got: %s", hint) + } +} + +func TestEKGHintForTaskSkipsWeakTaskMatch(t *testing.T) { + t.Parallel() + + workspace := t.TempDir() + memoryDir := filepath.Join(workspace, "memory") + if err := os.MkdirAll(memoryDir, 0o755); err != nil { + t.Fatalf("mkdir memory: %v", err) + } + logText := "dial tcp 10.0.0.8:443: i/o timeout" + taskAudit := `{"task_id":"task-3","status":"error","source":"planner","channel":"chat","input_preview":"investigate cache timeout","log":"dial tcp 10.0.0.8:443: i/o timeout"}` + if err := os.WriteFile(filepath.Join(memoryDir, "task-audit.jsonl"), []byte(taskAudit+"\n"), 0o644); err != nil { + t.Fatalf("write task audit: %v", err) + } + errSig := ekg.NormalizeErrorSignature(logText) + ekgEvents := []string{ + fmt.Sprintf(`{"task_id":"task-3","status":"error","errsig":"%s","log":"%s"}`, errSig, logText), + fmt.Sprintf(`{"task_id":"task-3","status":"error","errsig":"%s","log":"%s"}`, errSig, logText), + fmt.Sprintf(`{"task_id":"task-3","status":"error","errsig":"%s","log":"%s"}`, errSig, logText), + } + if err := os.WriteFile(filepath.Join(memoryDir, "ekg-events.jsonl"), []byte(strings.Join(ekgEvents, "\n")+"\n"), 0o644); err != nil { + t.Fatalf("write ekg events: %v", err) + } + + al := &AgentLoop{workspace: workspace, ekg: ekg.New(workspace)} + hint := al.ekgHintForTask(plannedTask{Content: "cache rebuild"}) + if hint != "" { + t.Fatalf("expected weak match to skip ekg hint, got: %s", hint) + } +} + +func TestCompactMemoryHintDropsVerboseScaffolding(t *testing.T) { + t.Parallel() + + raw := "Found 1 memories for 'deploy config' (namespace=main):\n\nSource: memory/2026-03-10.md#L12-L18\nDeploy must restart config service after updating the file.\nValidate permissions before rollout.\n\n" + got := compactMemoryHint(raw, 120) + if strings.Contains(strings.ToLower(got), "found 1 memories") { + t.Fatalf("expected summary header removed, got: %s", got) + } + if !strings.Contains(got, "src=memory/2026-03-10.md#L12-L18") { + t.Fatalf("expected source to remain compactly, got: %s", got) + } + if !strings.Contains(got, "Deploy must restart config service") { + t.Fatalf("expected main snippet retained, got: %s", got) + } +} + +func TestDedupeTaskPromptHintsDropsRepeatedContext(t *testing.T) { + t.Parallel() + + seen := map[string]struct{}{} + first := taskPromptHints{ekg: "repeat_errsig=x; backoff=300s", memory: "src=memory/a.md#L1-L2 | restart service"} + second := taskPromptHints{ekg: "repeat_errsig=x; backoff=300s", memory: "src=memory/a.md#L1-L2 | restart service"} + + dedupeTaskPromptHints(&first, seen) + dedupeTaskPromptHints(&second, seen) + + if first.ekg == "" || first.memory == "" { + t.Fatalf("expected first hint set to remain: %+v", first) + } + if second.ekg != "" || second.memory != "" { + t.Fatalf("expected repeated hints removed: %+v", second) + } +} + +func TestRenderTaskPromptWithHintsKeepsCompactShape(t *testing.T) { + t.Parallel() + + got := renderTaskPromptWithHints("deploy config service", taskPromptHints{ + ekg: "repeat_errsig=perm; backoff=300s", + memory: "src=memory/x.md#L1-L2 | restart after change", + }) + if !strings.Contains(got, "Task Context:\nEKG: repeat_errsig=perm; backoff=300s\nMemory: src=memory/x.md#L1-L2 | restart after change\nTask:\ndeploy config service") { + t.Fatalf("unexpected prompt shape: %s", got) + } +} + +func TestBuildPlannedTaskPromptTracksExtraChars(t *testing.T) { + t.Parallel() + + prompt := buildPlannedTaskPrompt("deploy config service", taskPromptHints{ + ekg: "repeat_errsig=perm; backoff=300s", + memory: "src=memory/x.md#L1-L2 | restart after change", + }) + if prompt.content == "" { + t.Fatalf("expected prompt content") + } + if prompt.extraChars <= 0 || prompt.ekgChars <= 0 || prompt.memoryChars <= 0 { + t.Fatalf("expected prompt stats populated: %+v", prompt) + } +} + +func TestApplyPromptBudgetPrefersEKGOverMemory(t *testing.T) { + t.Parallel() + + hints := taskPromptHints{ + ekg: "repeat_errsig=perm; backoff=300s", + memory: "src=memory/x.md#L1-L2 | restart after change", + } + remaining := estimateHintChars(hints) - len(hints.memory) + applyPromptBudget(&hints, &remaining) + if hints.ekg == "" { + t.Fatalf("expected ekg retained") + } + if hints.memory != "" { + t.Fatalf("expected memory dropped under budget pressure: %+v", hints) + } +} + +func TestApplyPromptBudgetDropsAllWhenBudgetExhausted(t *testing.T) { + t.Parallel() + + hints := taskPromptHints{ + ekg: "repeat_errsig=perm; backoff=300s", + memory: "src=memory/x.md#L1-L2 | restart after change", + } + remaining := 8 + applyPromptBudget(&hints, &remaining) + if hints.ekg != "" || hints.memory != "" { + t.Fatalf("expected all hints removed under tiny budget: %+v", hints) + } +} diff --git a/pkg/api/server.go b/pkg/api/server.go index ea47f5d..c07ebd3 100644 --- a/pkg/api/server.go +++ b/pkg/api/server.go @@ -31,6 +31,7 @@ import ( "github.com/YspCoder/clawgo/pkg/channels" cfgpkg "github.com/YspCoder/clawgo/pkg/config" "github.com/YspCoder/clawgo/pkg/nodes" + "github.com/YspCoder/clawgo/pkg/providers" "github.com/YspCoder/clawgo/pkg/tools" "github.com/gorilla/websocket" "rsc.io/qr" @@ -73,6 +74,8 @@ type Server struct { liveSubagents map[string]*liveSubagentGroup whatsAppBridge *channels.WhatsAppBridgeService whatsAppBase string + oauthFlowMu sync.Mutex + oauthFlows map[string]*providers.OAuthPendingFlow } var nodesWebsocketUpgrader = websocket.Upgrader{ @@ -96,6 +99,7 @@ func NewServer(host string, port int, token string, mgr *nodes.Manager) *Server artifactStats: map[string]interface{}{}, liveRuntimeSubs: map[chan []byte]struct{}{}, liveSubagents: map[string]*liveSubagentGroup{}, + oauthFlows: map[string]*providers.OAuthPendingFlow{}, } } @@ -449,6 +453,12 @@ func (s *Server) Start(ctx context.Context) error { mux.HandleFunc("/webui/api/chat/live", s.handleWebUIChatLive) mux.HandleFunc("/webui/api/runtime", s.handleWebUIRuntime) mux.HandleFunc("/webui/api/version", s.handleWebUIVersion) + mux.HandleFunc("/webui/api/provider/oauth/start", s.handleWebUIProviderOAuthStart) + mux.HandleFunc("/webui/api/provider/oauth/complete", s.handleWebUIProviderOAuthComplete) + mux.HandleFunc("/webui/api/provider/oauth/import", s.handleWebUIProviderOAuthImport) + mux.HandleFunc("/webui/api/provider/oauth/accounts", s.handleWebUIProviderOAuthAccounts) + mux.HandleFunc("/webui/api/provider/runtime", s.handleWebUIProviderRuntime) + mux.HandleFunc("/webui/api/provider/runtime/summary", s.handleWebUIProviderRuntimeSummary) mux.HandleFunc("/webui/api/whatsapp/status", s.handleWebUIWhatsAppStatus) mux.HandleFunc("/webui/api/whatsapp/logout", s.handleWebUIWhatsAppLogout) mux.HandleFunc("/webui/api/whatsapp/qr.svg", s.handleWebUIWhatsAppQR) @@ -979,6 +989,499 @@ func (s *Server) handleWebUIUpload(w http.ResponseWriter, r *http.Request) { _ = json.NewEncoder(w).Encode(map[string]interface{}{"ok": true, "path": path, "name": h.Filename}) } +func (s *Server) handleWebUIProviderOAuthStart(w http.ResponseWriter, r *http.Request) { + if !s.checkAuth(r) { + http.Error(w, "unauthorized", http.StatusUnauthorized) + return + } + if r.Method != http.MethodPost && r.Method != http.MethodGet { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + var body struct { + Provider string `json:"provider"` + AccountLabel string `json:"account_label"` + ProviderConfig cfgpkg.ProviderConfig `json:"provider_config"` + } + if r.Method == http.MethodPost { + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { + http.Error(w, "invalid json", http.StatusBadRequest) + return + } + } else { + body.Provider = strings.TrimSpace(r.URL.Query().Get("provider")) + body.AccountLabel = strings.TrimSpace(r.URL.Query().Get("account_label")) + } + cfg, pc, err := s.resolveProviderConfig(strings.TrimSpace(body.Provider), body.ProviderConfig) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + _ = cfg + timeout := pc.TimeoutSec + if timeout <= 0 { + timeout = 90 + } + loginMgr, err := providers.NewOAuthLoginManager(pc, time.Duration(timeout)*time.Second) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + flow, err := loginMgr.StartManualFlow() + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + flowID := fmt.Sprintf("%d", time.Now().UnixNano()) + s.oauthFlowMu.Lock() + s.oauthFlows[flowID] = flow + s.oauthFlowMu.Unlock() + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "ok": true, + "flow_id": flowID, + "mode": flow.Mode, + "auth_url": flow.AuthURL, + "user_code": flow.UserCode, + "instructions": flow.Instructions, + "account_label": strings.TrimSpace(body.AccountLabel), + }) +} + +func (s *Server) handleWebUIProviderOAuthComplete(w http.ResponseWriter, r *http.Request) { + if !s.checkAuth(r) { + http.Error(w, "unauthorized", http.StatusUnauthorized) + return + } + if r.Method != http.MethodPost { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + var body struct { + Provider string `json:"provider"` + FlowID string `json:"flow_id"` + CallbackURL string `json:"callback_url"` + AccountLabel string `json:"account_label"` + ProviderConfig cfgpkg.ProviderConfig `json:"provider_config"` + } + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { + http.Error(w, "invalid json", http.StatusBadRequest) + return + } + cfg, pc, err := s.resolveProviderConfig(strings.TrimSpace(body.Provider), body.ProviderConfig) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + timeout := pc.TimeoutSec + if timeout <= 0 { + timeout = 90 + } + loginMgr, err := providers.NewOAuthLoginManager(pc, time.Duration(timeout)*time.Second) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + s.oauthFlowMu.Lock() + flow := s.oauthFlows[strings.TrimSpace(body.FlowID)] + delete(s.oauthFlows, strings.TrimSpace(body.FlowID)) + s.oauthFlowMu.Unlock() + if flow == nil { + http.Error(w, "oauth flow not found", http.StatusBadRequest) + return + } + session, models, err := loginMgr.CompleteManualFlowWithOptions(r.Context(), pc.APIBase, flow, body.CallbackURL, providers.OAuthLoginOptions{ + AccountLabel: strings.TrimSpace(body.AccountLabel), + }) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + if len(models) > 0 { + pc.Models = models + } + if session.CredentialFile != "" { + pc.OAuth.CredentialFile = session.CredentialFile + pc.OAuth.CredentialFiles = appendUniqueStrings(pc.OAuth.CredentialFiles, session.CredentialFile) + } + if err := s.saveProviderConfig(cfg, body.Provider, pc); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "ok": true, + "account": session.Email, + "credential_file": session.CredentialFile, + "models": models, + }) +} + +func (s *Server) handleWebUIProviderOAuthImport(w http.ResponseWriter, r *http.Request) { + if !s.checkAuth(r) { + http.Error(w, "unauthorized", http.StatusUnauthorized) + return + } + if r.Method != http.MethodPost { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + if err := r.ParseMultipartForm(16 << 20); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + providerName := strings.TrimSpace(r.FormValue("provider")) + accountLabel := strings.TrimSpace(r.FormValue("account_label")) + inlineCfgRaw := strings.TrimSpace(r.FormValue("provider_config")) + var inlineCfg cfgpkg.ProviderConfig + if inlineCfgRaw != "" { + if err := json.Unmarshal([]byte(inlineCfgRaw), &inlineCfg); err != nil { + http.Error(w, "invalid provider_config", http.StatusBadRequest) + return + } + } + cfg, pc, err := s.resolveProviderConfig(providerName, inlineCfg) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + file, header, err := r.FormFile("file") + if err != nil { + http.Error(w, "file required", http.StatusBadRequest) + return + } + defer file.Close() + data, err := io.ReadAll(file) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + timeout := pc.TimeoutSec + if timeout <= 0 { + timeout = 90 + } + loginMgr, err := providers.NewOAuthLoginManager(pc, time.Duration(timeout)*time.Second) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + session, models, err := loginMgr.ImportAuthJSONWithOptions(r.Context(), pc.APIBase, header.Filename, data, providers.OAuthLoginOptions{ + AccountLabel: accountLabel, + }) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + if len(models) > 0 { + pc.Models = models + } + if session.CredentialFile != "" { + pc.OAuth.CredentialFile = session.CredentialFile + pc.OAuth.CredentialFiles = appendUniqueStrings(pc.OAuth.CredentialFiles, session.CredentialFile) + } + if err := s.saveProviderConfig(cfg, providerName, pc); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "ok": true, + "account": session.Email, + "credential_file": session.CredentialFile, + "models": models, + }) +} + +func (s *Server) handleWebUIProviderOAuthAccounts(w http.ResponseWriter, r *http.Request) { + if !s.checkAuth(r) { + http.Error(w, "unauthorized", http.StatusUnauthorized) + return + } + providerName := strings.TrimSpace(r.URL.Query().Get("provider")) + cfg, pc, err := s.loadProviderConfig(providerName) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + _ = cfg + timeout := pc.TimeoutSec + if timeout <= 0 { + timeout = 90 + } + loginMgr, err := providers.NewOAuthLoginManager(pc, time.Duration(timeout)*time.Second) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + switch r.Method { + case http.MethodGet: + accounts, err := loginMgr.ListAccounts() + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + _ = json.NewEncoder(w).Encode(map[string]interface{}{"ok": true, "accounts": accounts}) + case http.MethodPost: + var body struct { + Action string `json:"action"` + CredentialFile string `json:"credential_file"` + } + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { + http.Error(w, "invalid json", http.StatusBadRequest) + return + } + switch strings.ToLower(strings.TrimSpace(body.Action)) { + case "refresh": + account, err := loginMgr.RefreshAccount(r.Context(), body.CredentialFile) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + _ = json.NewEncoder(w).Encode(map[string]interface{}{"ok": true, "account": account}) + case "delete": + if err := loginMgr.DeleteAccount(body.CredentialFile); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + pc.OAuth.CredentialFiles = removeStringItem(pc.OAuth.CredentialFiles, body.CredentialFile) + if strings.TrimSpace(pc.OAuth.CredentialFile) == strings.TrimSpace(body.CredentialFile) { + pc.OAuth.CredentialFile = "" + if len(pc.OAuth.CredentialFiles) > 0 { + pc.OAuth.CredentialFile = pc.OAuth.CredentialFiles[0] + } + } + if err := s.saveProviderConfig(cfg, providerName, pc); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + _ = json.NewEncoder(w).Encode(map[string]interface{}{"ok": true, "deleted": true}) + case "clear_cooldown": + if err := loginMgr.ClearCooldown(body.CredentialFile); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + _ = json.NewEncoder(w).Encode(map[string]interface{}{"ok": true, "cleared": true}) + default: + http.Error(w, "unsupported action", http.StatusBadRequest) + } + default: + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + } +} + +func (s *Server) handleWebUIProviderRuntime(w http.ResponseWriter, r *http.Request) { + if !s.checkAuth(r) { + http.Error(w, "unauthorized", http.StatusUnauthorized) + return + } + if r.Method == http.MethodGet { + cfg, err := cfgpkg.LoadConfig(s.configPath) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + query := providers.ProviderRuntimeQuery{ + Provider: strings.TrimSpace(r.URL.Query().Get("provider")), + EventKind: strings.TrimSpace(r.URL.Query().Get("kind")), + Reason: strings.TrimSpace(r.URL.Query().Get("reason")), + Target: strings.TrimSpace(r.URL.Query().Get("target")), + Sort: strings.TrimSpace(r.URL.Query().Get("sort")), + ChangesOnly: strings.EqualFold(strings.TrimSpace(r.URL.Query().Get("changes_only")), "true"), + } + if secs, _ := strconv.Atoi(strings.TrimSpace(r.URL.Query().Get("window_sec"))); secs > 0 { + query.Window = time.Duration(secs) * time.Second + } + if limit, _ := strconv.Atoi(strings.TrimSpace(r.URL.Query().Get("limit"))); limit > 0 { + query.Limit = limit + } + if cursor, _ := strconv.Atoi(strings.TrimSpace(r.URL.Query().Get("cursor"))); cursor >= 0 { + query.Cursor = cursor + } + if healthBelow, _ := strconv.Atoi(strings.TrimSpace(r.URL.Query().Get("health_below"))); healthBelow > 0 { + query.HealthBelow = healthBelow + } + if secs, _ := strconv.Atoi(strings.TrimSpace(r.URL.Query().Get("cooldown_until_before_sec"))); secs > 0 { + query.CooldownBefore = time.Now().Add(time.Duration(secs) * time.Second) + } + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "ok": true, + "view": providers.GetProviderRuntimeView(cfg, query), + }) + return + } + if r.Method != http.MethodPost { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + var body struct { + Provider string `json:"provider"` + Action string `json:"action"` + OnlyExpiring bool `json:"only_expiring"` + } + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { + http.Error(w, "invalid json", http.StatusBadRequest) + return + } + switch strings.ToLower(strings.TrimSpace(body.Action)) { + case "clear_api_cooldown": + providers.ClearProviderAPICooldown(strings.TrimSpace(body.Provider)) + _ = json.NewEncoder(w).Encode(map[string]interface{}{"ok": true, "cleared": true}) + case "clear_history": + providers.ClearProviderRuntimeHistory(strings.TrimSpace(body.Provider)) + _ = json.NewEncoder(w).Encode(map[string]interface{}{"ok": true, "cleared": true}) + case "refresh_now": + cfg, err := cfgpkg.LoadConfig(s.configPath) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + result, err := providers.RefreshProviderRuntimeNow(cfg, strings.TrimSpace(body.Provider), body.OnlyExpiring) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + _ = json.NewEncoder(w).Encode(map[string]interface{}{"ok": true, "refreshed": true, "result": result}) + case "rerank": + cfg, err := cfgpkg.LoadConfig(s.configPath) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + order, err := providers.RerankProviderRuntime(cfg, strings.TrimSpace(body.Provider)) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + _ = json.NewEncoder(w).Encode(map[string]interface{}{"ok": true, "reranked": true, "candidate_order": order}) + default: + http.Error(w, "unsupported action", http.StatusBadRequest) + } +} + +func (s *Server) handleWebUIProviderRuntimeSummary(w http.ResponseWriter, r *http.Request) { + if !s.checkAuth(r) { + http.Error(w, "unauthorized", http.StatusUnauthorized) + return + } + if r.Method != http.MethodGet { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + cfg, err := cfgpkg.LoadConfig(s.configPath) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + query := providers.ProviderRuntimeQuery{ + Provider: strings.TrimSpace(r.URL.Query().Get("provider")), + Reason: strings.TrimSpace(r.URL.Query().Get("reason")), + Target: strings.TrimSpace(r.URL.Query().Get("target")), + } + if secs, _ := strconv.Atoi(strings.TrimSpace(r.URL.Query().Get("window_sec"))); secs > 0 { + query.Window = time.Duration(secs) * time.Second + } + if healthBelow, _ := strconv.Atoi(strings.TrimSpace(r.URL.Query().Get("health_below"))); healthBelow > 0 { + query.HealthBelow = healthBelow + } + if query.HealthBelow <= 0 { + query.HealthBelow = 50 + } + if secs, _ := strconv.Atoi(strings.TrimSpace(r.URL.Query().Get("cooldown_until_before_sec"))); secs > 0 { + query.CooldownBefore = time.Now().Add(time.Duration(secs) * time.Second) + } + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "ok": true, + "summary": providers.GetProviderRuntimeSummary(cfg, query), + }) +} + +func (s *Server) loadProviderConfig(name string) (*cfgpkg.Config, cfgpkg.ProviderConfig, error) { + if strings.TrimSpace(s.configPath) == "" { + return nil, cfgpkg.ProviderConfig{}, fmt.Errorf("config path not set") + } + cfg, err := cfgpkg.LoadConfig(s.configPath) + if err != nil { + return nil, cfgpkg.ProviderConfig{}, err + } + providerName := strings.TrimSpace(name) + if providerName == "" || providerName == "proxy" { + return cfg, cfg.Providers.Proxy, nil + } + pc, ok := cfg.Providers.Proxies[providerName] + if !ok { + return nil, cfgpkg.ProviderConfig{}, fmt.Errorf("provider %q not found", providerName) + } + return cfg, pc, nil +} + +func (s *Server) resolveProviderConfig(name string, inline cfgpkg.ProviderConfig) (*cfgpkg.Config, cfgpkg.ProviderConfig, error) { + if hasInlineProviderConfig(inline) { + cfg, err := cfgpkg.LoadConfig(s.configPath) + if err != nil { + return nil, cfgpkg.ProviderConfig{}, err + } + return cfg, inline, nil + } + return s.loadProviderConfig(name) +} + +func hasInlineProviderConfig(pc cfgpkg.ProviderConfig) bool { + return strings.TrimSpace(pc.APIBase) != "" || + strings.TrimSpace(pc.APIKey) != "" || + len(pc.Models) > 0 || + strings.TrimSpace(pc.Auth) != "" || + strings.TrimSpace(pc.OAuth.Provider) != "" +} + +func (s *Server) saveProviderConfig(cfg *cfgpkg.Config, name string, pc cfgpkg.ProviderConfig) error { + if cfg == nil { + return fmt.Errorf("config is nil") + } + providerName := strings.TrimSpace(name) + if providerName == "" || providerName == "proxy" { + cfg.Providers.Proxy = pc + } else { + if cfg.Providers.Proxies == nil { + cfg.Providers.Proxies = map[string]cfgpkg.ProviderConfig{} + } + cfg.Providers.Proxies[providerName] = pc + } + if err := cfgpkg.SaveConfig(s.configPath, cfg); err != nil { + return err + } + if s.onConfigAfter != nil { + s.onConfigAfter() + } else { + _ = requestSelfReloadSignal() + } + return nil +} + +func appendUniqueStrings(values []string, item string) []string { + item = strings.TrimSpace(item) + if item == "" { + return values + } + for _, value := range values { + if strings.TrimSpace(value) == item { + return values + } + } + return append(values, item) +} + +func removeStringItem(values []string, item string) []string { + item = strings.TrimSpace(item) + if item == "" { + return values + } + out := make([]string, 0, len(values)) + for _, value := range values { + if strings.TrimSpace(value) == item { + continue + } + out = append(out, value) + } + return out +} + func (s *Server) handleWebUIChat(w http.ResponseWriter, r *http.Request) { if !s.checkAuth(r) { http.Error(w, "unauthorized", http.StatusUnauthorized) @@ -1485,6 +1988,15 @@ func (s *Server) handleWebUIRuntime(w http.ResponseWriter, r *http.Request) { } func (s *Server) buildWebUIRuntimeSnapshot(ctx context.Context) map[string]interface{} { + var providerPayload map[string]interface{} + if strings.TrimSpace(s.configPath) != "" { + if cfg, err := cfgpkg.LoadConfig(strings.TrimSpace(s.configPath)); err == nil { + providerPayload = providers.GetProviderRuntimeSnapshot(cfg) + } + } + if providerPayload == nil { + providerPayload = map[string]interface{}{"items": []interface{}{}} + } return map[string]interface{}{ "version": s.webUIVersionPayload(), "nodes": s.webUINodesPayload(ctx), @@ -1492,6 +2004,7 @@ func (s *Server) buildWebUIRuntimeSnapshot(ctx context.Context) map[string]inter "task_queue": s.webUITaskQueuePayload(false), "ekg": s.webUIEKGSummaryPayload("24h"), "subagents": s.webUISubagentsRuntimePayload(ctx), + "providers": providerPayload, } } diff --git a/pkg/config/config.go b/pkg/config/config.go index 44a2d76..ba61f8b 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -289,9 +289,30 @@ type ProviderConfig struct { SupportsResponsesCompact bool `json:"supports_responses_compact" env:"CLAWGO_PROVIDERS_{{.Name}}_SUPPORTS_RESPONSES_COMPACT"` Auth string `json:"auth" env:"CLAWGO_PROVIDERS_{{.Name}}_AUTH"` TimeoutSec int `json:"timeout_sec" env:"CLAWGO_PROVIDERS_PROXY_TIMEOUT_SEC"` + RuntimePersist bool `json:"runtime_persist,omitempty"` + RuntimeHistoryFile string `json:"runtime_history_file,omitempty"` + RuntimeHistoryMax int `json:"runtime_history_max,omitempty"` + OAuth ProviderOAuthConfig `json:"oauth,omitempty"` Responses ProviderResponsesConfig `json:"responses"` } +type ProviderOAuthConfig struct { + Provider string `json:"provider,omitempty"` + CredentialFile string `json:"credential_file,omitempty"` + CredentialFiles []string `json:"credential_files,omitempty"` + CallbackPort int `json:"callback_port,omitempty"` + ClientID string `json:"client_id,omitempty"` + ClientSecret string `json:"client_secret,omitempty"` + AuthURL string `json:"auth_url,omitempty"` + TokenURL string `json:"token_url,omitempty"` + RedirectURL string `json:"redirect_url,omitempty"` + Scopes []string `json:"scopes,omitempty"` + HybridPriority string `json:"hybrid_priority,omitempty"` + CooldownSec int `json:"cooldown_sec,omitempty"` + RefreshScanSec int `json:"refresh_scan_sec,omitempty"` + RefreshLeadSec int `json:"refresh_lead_sec,omitempty"` +} + type ProviderResponsesConfig struct { WebSearchEnabled bool `json:"web_search_enabled"` WebSearchContextSize string `json:"web_search_context_size"` @@ -419,6 +440,7 @@ type SentinelConfig struct { AutoHeal bool `json:"auto_heal" env:"CLAWGO_SENTINEL_AUTO_HEAL"` NotifyChannel string `json:"notify_channel" env:"CLAWGO_SENTINEL_NOTIFY_CHANNEL"` NotifyChatID string `json:"notify_chat_id" env:"CLAWGO_SENTINEL_NOTIFY_CHAT_ID"` + WebhookURL string `json:"webhook_url" env:"CLAWGO_SENTINEL_WEBHOOK_URL"` } type MemoryConfig struct { @@ -659,6 +681,7 @@ func DefaultConfig() *Config { AutoHeal: true, NotifyChannel: "", NotifyChatID: "", + WebhookURL: "", }, Memory: MemoryConfig{ Layered: true, diff --git a/pkg/config/validate.go b/pkg/config/validate.go index 9c4e2c5..b55b4a3 100644 --- a/pkg/config/validate.go +++ b/pkg/config/validate.go @@ -2,6 +2,7 @@ package config import ( "fmt" + "net/url" "path/filepath" "strings" ) @@ -203,6 +204,18 @@ func Validate(cfg *Config) []error { if cfg.Sentinel.Enabled && cfg.Sentinel.IntervalSec <= 0 { errs = append(errs, fmt.Errorf("sentinel.interval_sec must be > 0 when sentinel.enabled=true")) } + if raw := strings.TrimSpace(cfg.Sentinel.WebhookURL); raw != "" { + u, err := url.Parse(raw) + if err != nil || u == nil || u.Host == "" { + errs = append(errs, fmt.Errorf("sentinel.webhook_url must be a valid http/https URL")) + } else { + switch strings.ToLower(strings.TrimSpace(u.Scheme)) { + case "http", "https": + default: + errs = append(errs, fmt.Errorf("sentinel.webhook_url must use http or https")) + } + } + } if cfg.Memory.RecentDays <= 0 { errs = append(errs, fmt.Errorf("memory.recent_days must be > 0")) } @@ -523,15 +536,42 @@ func containsString(items []string, target string) bool { func validateProviderConfig(path string, p ProviderConfig) []error { var errs []error + authMode := strings.ToLower(strings.TrimSpace(p.Auth)) if p.APIBase == "" { errs = append(errs, fmt.Errorf("%s.api_base is required", path)) } if p.TimeoutSec <= 0 { errs = append(errs, fmt.Errorf("%s.timeout_sec must be > 0", path)) } - if len(p.Models) == 0 { + switch authMode { + case "", "bearer", "oauth", "none", "hybrid": + default: + errs = append(errs, fmt.Errorf("%s.auth must be one of: bearer, oauth, hybrid, none", path)) + } + if len(p.Models) == 0 && authMode != "oauth" && authMode != "hybrid" { errs = append(errs, fmt.Errorf("%s.models must contain at least one model", path)) } + if authMode == "oauth" && strings.TrimSpace(p.OAuth.Provider) == "" { + errs = append(errs, fmt.Errorf("%s.oauth.provider is required when auth=oauth", path)) + } + if authMode == "hybrid" { + if strings.TrimSpace(p.APIKey) == "" && strings.TrimSpace(p.OAuth.Provider) == "" { + errs = append(errs, fmt.Errorf("%s.hybrid auth requires api_key or oauth.provider", path)) + } + if strings.TrimSpace(p.OAuth.Provider) == "" { + errs = append(errs, fmt.Errorf("%s.oauth.provider is required when auth=hybrid", path)) + } + } + if p.OAuth.HybridPriority != "" { + switch strings.ToLower(strings.TrimSpace(p.OAuth.HybridPriority)) { + case "api_first", "oauth_first": + default: + errs = append(errs, fmt.Errorf("%s.oauth.hybrid_priority must be one of: api_first, oauth_first", path)) + } + } + if p.OAuth.CooldownSec < 0 { + errs = append(errs, fmt.Errorf("%s.oauth.cooldown_sec must be >= 0", path)) + } if p.Responses.WebSearchContextSize != "" { switch p.Responses.WebSearchContextSize { case "low", "medium", "high": diff --git a/pkg/config/validate_test.go b/pkg/config/validate_test.go index bd8582a..64d8c6a 100644 --- a/pkg/config/validate_test.go +++ b/pkg/config/validate_test.go @@ -1,6 +1,9 @@ package config -import "testing" +import ( + "strings" + "testing" +) func TestDefaultConfigGeneratesGatewayToken(t *testing.T) { t.Parallel() @@ -204,6 +207,30 @@ func TestValidateGatewayNodeDispatchRejectsEmptyAllowNodeKey(t *testing.T) { } } +func TestValidateSentinelWebhookURLRejectsInvalidScheme(t *testing.T) { + t.Parallel() + + cfg := DefaultConfig() + cfg.Sentinel.WebhookURL = "ftp://example.com/hook" + + if errs := Validate(cfg); len(errs) == 0 { + t.Fatalf("expected validation errors") + } +} + +func TestValidateSentinelWebhookURLAllowsHTTPS(t *testing.T) { + t.Parallel() + + cfg := DefaultConfig() + cfg.Sentinel.WebhookURL = "https://example.com/hook" + + for _, err := range Validate(cfg) { + if strings.Contains(err.Error(), "sentinel.webhook_url") { + t.Fatalf("unexpected webhook validation error: %v", err) + } + } +} + func TestDefaultConfigSetsNodeArtifactRetentionDefaults(t *testing.T) { t.Parallel() @@ -244,3 +271,106 @@ func TestValidateNodeArtifactRetentionRejectsNegativeRetainDays(t *testing.T) { t.Fatalf("expected validation errors") } } + +func TestValidateProviderOAuthAllowsEmptyModelsBeforeLogin(t *testing.T) { + t.Parallel() + + cfg := DefaultConfig() + cfg.Providers.Proxy.Auth = "oauth" + cfg.Providers.Proxy.Models = nil + cfg.Providers.Proxy.OAuth = ProviderOAuthConfig{Provider: "codex"} + + if errs := Validate(cfg); len(errs) != 0 { + t.Fatalf("expected oauth provider config to be valid before model sync, got %v", errs) + } +} + +func TestValidateProviderOAuthRequiresProviderName(t *testing.T) { + t.Parallel() + + cfg := DefaultConfig() + cfg.Providers.Proxy.Auth = "oauth" + cfg.Providers.Proxy.Models = nil + cfg.Providers.Proxy.OAuth = ProviderOAuthConfig{} + + errs := Validate(cfg) + if len(errs) == 0 { + t.Fatalf("expected validation errors") + } + found := false + for _, err := range errs { + if strings.Contains(err.Error(), "providers.proxy.oauth.provider") { + found = true + break + } + } + if !found { + t.Fatalf("expected oauth.provider validation error, got %v", errs) + } +} + +func TestValidateProviderHybridAllowsEmptyModels(t *testing.T) { + t.Parallel() + + cfg := DefaultConfig() + cfg.Providers.Proxy.Auth = "hybrid" + cfg.Providers.Proxy.APIKey = "sk-test" + cfg.Providers.Proxy.Models = nil + cfg.Providers.Proxy.OAuth = ProviderOAuthConfig{Provider: "codex"} + + if errs := Validate(cfg); len(errs) != 0 { + t.Fatalf("expected hybrid provider config to be valid before model sync, got %v", errs) + } +} + +func TestValidateProviderHybridRequiresOAuthProvider(t *testing.T) { + t.Parallel() + + cfg := DefaultConfig() + cfg.Providers.Proxy.Auth = "hybrid" + cfg.Providers.Proxy.APIKey = "sk-test" + cfg.Providers.Proxy.Models = nil + cfg.Providers.Proxy.OAuth = ProviderOAuthConfig{} + + errs := Validate(cfg) + if len(errs) == 0 { + t.Fatalf("expected validation errors") + } + found := false + for _, err := range errs { + if strings.Contains(err.Error(), "providers.proxy.oauth.provider") { + found = true + break + } + } + if !found { + t.Fatalf("expected oauth.provider validation error, got %v", errs) + } +} + +func TestValidateProviderHybridPriorityRejectsInvalidValue(t *testing.T) { + t.Parallel() + + cfg := DefaultConfig() + cfg.Providers.Proxy.Auth = "hybrid" + cfg.Providers.Proxy.APIKey = "sk-test" + cfg.Providers.Proxy.OAuth = ProviderOAuthConfig{ + Provider: "codex", + HybridPriority: "random_first", + } + + errs := Validate(cfg) + if len(errs) == 0 { + t.Fatalf("expected validation errors") + } + found := false + for _, err := range errs { + if strings.Contains(err.Error(), "oauth.hybrid_priority") { + found = true + break + } + } + if !found { + t.Fatalf("expected oauth.hybrid_priority validation error, got %v", errs) + } +} diff --git a/pkg/providers/anthropic_transport.go b/pkg/providers/anthropic_transport.go new file mode 100644 index 0000000..f9502ab --- /dev/null +++ b/pkg/providers/anthropic_transport.go @@ -0,0 +1,107 @@ +package providers + +import ( + "net" + "net/http" + "strings" + "sync" + "time" + + tls "github.com/refraction-networking/utls" + "golang.org/x/net/http2" +) + +type anthropicOAuthRoundTripper struct { + mu sync.Mutex + connections map[string]*http2.ClientConn + pending map[string]*sync.Cond + dialer net.Dialer +} + +func newAnthropicOAuthHTTPClient(timeout time.Duration) *http.Client { + return &http.Client{ + Timeout: timeout, + Transport: newAnthropicOAuthRoundTripper(), + } +} + +func newAnthropicOAuthRoundTripper() *anthropicOAuthRoundTripper { + return &anthropicOAuthRoundTripper{ + connections: map[string]*http2.ClientConn{}, + pending: map[string]*sync.Cond{}, + dialer: net.Dialer{ + Timeout: 15 * time.Second, + KeepAlive: 30 * time.Second, + }, + } +} + +func (t *anthropicOAuthRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + host := req.URL.Hostname() + addr := req.URL.Host + if !strings.Contains(addr, ":") { + addr += ":443" + } + conn, err := t.getOrCreateConnection(host, addr) + if err != nil { + return nil, err + } + resp, err := conn.RoundTrip(req) + if err != nil { + t.mu.Lock() + if cached, ok := t.connections[host]; ok && cached == conn { + delete(t.connections, host) + } + t.mu.Unlock() + return nil, err + } + return resp, nil +} + +func (t *anthropicOAuthRoundTripper) getOrCreateConnection(host, addr string) (*http2.ClientConn, error) { + t.mu.Lock() + if conn, ok := t.connections[host]; ok && conn.CanTakeNewRequest() { + t.mu.Unlock() + return conn, nil + } + if wait, ok := t.pending[host]; ok { + wait.Wait() + if conn, ok := t.connections[host]; ok && conn.CanTakeNewRequest() { + t.mu.Unlock() + return conn, nil + } + } + wait := sync.NewCond(&t.mu) + t.pending[host] = wait + t.mu.Unlock() + + conn, err := t.createConnection(host, addr) + + t.mu.Lock() + defer t.mu.Unlock() + delete(t.pending, host) + wait.Broadcast() + if err != nil { + return nil, err + } + t.connections[host] = conn + return conn, nil +} + +func (t *anthropicOAuthRoundTripper) createConnection(host, addr string) (*http2.ClientConn, error) { + rawConn, err := t.dialer.Dial("tcp", addr) + if err != nil { + return nil, err + } + tlsConn := tls.UClient(rawConn, &tls.Config{ServerName: host}, tls.HelloChrome_Auto) + if err := tlsConn.Handshake(); err != nil { + _ = rawConn.Close() + return nil, err + } + h2Conn, err := (&http2.Transport{}).NewClientConn(tlsConn) + if err != nil { + _ = tlsConn.Close() + return nil, err + } + return h2Conn, nil +} diff --git a/pkg/providers/http_provider.go b/pkg/providers/http_provider.go index af5963a..a3cd7c0 100644 --- a/pkg/providers/http_provider.go +++ b/pkg/providers/http_provider.go @@ -11,12 +11,145 @@ import ( "io" "net/http" "net/url" + "os" + "path/filepath" "regexp" "strings" + "sync" "time" ) +type providerAPIRuntimeState struct { + TokenMasked string `json:"token_masked,omitempty"` + CooldownUntil string `json:"cooldown_until,omitempty"` + FailureCount int `json:"failure_count,omitempty"` + LastFailure string `json:"last_failure,omitempty"` + HealthScore int `json:"health_score,omitempty"` +} + +type providerRuntimeEvent struct { + When string `json:"when,omitempty"` + Kind string `json:"kind,omitempty"` + Target string `json:"target,omitempty"` + Reason string `json:"reason,omitempty"` + Detail string `json:"detail,omitempty"` +} + +func recordProviderRuntimeChange(providerName, kind, target, reason, detail string) { + name := strings.TrimSpace(providerName) + if name == "" || strings.TrimSpace(reason) == "" { + return + } + providerRuntimeRegistry.mu.Lock() + defer providerRuntimeRegistry.mu.Unlock() + state := providerRuntimeRegistry.api[name] + state.RecentChanges = appendRuntimeEvent(state.RecentChanges, providerRuntimeEvent{ + When: time.Now().Format(time.RFC3339), + Kind: strings.TrimSpace(kind), + Target: strings.TrimSpace(target), + Reason: strings.TrimSpace(reason), + Detail: strings.TrimSpace(detail), + }, runtimeEventLimit(state)) + persistProviderRuntimeLocked(name, state) + providerRuntimeRegistry.api[name] = state +} + +type providerRuntimeCandidate struct { + Kind string `json:"kind,omitempty"` + Target string `json:"target,omitempty"` + Available bool `json:"available"` + Status string `json:"status,omitempty"` + CooldownUntil string `json:"cooldown_until,omitempty"` + HealthScore int `json:"health_score,omitempty"` + FailureCount int `json:"failure_count,omitempty"` +} + +type providerRuntimePersistConfig struct { + Enabled bool + File string + MaxEvents int + Loaded bool + LoadAttempt bool +} + +type ProviderRuntimeQuery struct { + Provider string + Window time.Duration + EventKind string + Reason string + Target string + Limit int + Cursor int + HealthBelow int + CooldownBefore time.Time + Sort string + ChangesOnly bool +} + +type ProviderRefreshAccountResult struct { + Target string `json:"target,omitempty"` + Status string `json:"status,omitempty"` + Detail string `json:"detail,omitempty"` + Expire string `json:"expire,omitempty"` +} + +type ProviderRefreshResult struct { + Provider string `json:"provider,omitempty"` + Checked int `json:"checked,omitempty"` + Refreshed int `json:"refreshed,omitempty"` + Skipped int `json:"skipped,omitempty"` + Failed int `json:"failed,omitempty"` + Accounts []ProviderRefreshAccountResult `json:"accounts,omitempty"` +} + +type ProviderRuntimeSummaryItem struct { + Name string `json:"name,omitempty"` + Auth string `json:"auth,omitempty"` + Status string `json:"status,omitempty"` + APIState providerAPIRuntimeState `json:"api_state,omitempty"` + OAuthAccounts []OAuthAccountInfo `json:"oauth_accounts,omitempty"` + CandidateOrder []providerRuntimeCandidate `json:"candidate_order,omitempty"` + LastSuccess *providerRuntimeEvent `json:"last_success,omitempty"` + LastSuccessAt string `json:"last_success_at,omitempty"` + LastError *providerRuntimeEvent `json:"last_error,omitempty"` + LastErrorAt string `json:"last_error_at,omitempty"` + LastErrorReason string `json:"last_error_reason,omitempty"` + TopCandidateChangedAt string `json:"top_candidate_changed_at,omitempty"` + StaleForSec int64 `json:"stale_for_sec,omitempty"` + InCooldown bool `json:"in_cooldown"` + LowHealth bool `json:"low_health"` + HasRecentErrors bool `json:"has_recent_errors"` + TopCandidate *providerRuntimeCandidate `json:"top_candidate,omitempty"` +} + +type ProviderRuntimeSummary struct { + TotalProviders int `json:"total_providers"` + Healthy int `json:"healthy"` + Degraded int `json:"degraded"` + Critical int `json:"critical"` + InCooldown int `json:"in_cooldown"` + LowHealth int `json:"low_health"` + RecentErrors int `json:"recent_errors"` + Providers []ProviderRuntimeSummaryItem `json:"providers,omitempty"` +} + +type providerRuntimeState struct { + API providerAPIRuntimeState `json:"api_state,omitempty"` + RecentHits []providerRuntimeEvent `json:"recent_hits,omitempty"` + RecentErrors []providerRuntimeEvent `json:"recent_errors,omitempty"` + RecentChanges []providerRuntimeEvent `json:"recent_changes,omitempty"` + LastSuccess *providerRuntimeEvent `json:"last_success,omitempty"` + CandidateOrder []providerRuntimeCandidate `json:"candidate_order,omitempty"` + Persist providerRuntimePersistConfig `json:"-"` +} + +var providerRuntimeRegistry = struct { + mu sync.Mutex + api map[string]providerRuntimeState +}{api: map[string]providerRuntimeState{}} + type HTTPProvider struct { + providerName string apiKey string apiBase string defaultModel string @@ -24,11 +157,16 @@ type HTTPProvider struct { authMode string timeout time.Duration httpClient *http.Client + oauth *oauthManager } -func NewHTTPProvider(apiKey, apiBase, defaultModel string, supportsResponsesCompact bool, authMode string, timeout time.Duration) *HTTPProvider { +func NewHTTPProvider(providerName, apiKey, apiBase, defaultModel string, supportsResponsesCompact bool, authMode string, timeout time.Duration, oauth *oauthManager) *HTTPProvider { normalizedBase := normalizeAPIBase(apiBase) + if oauth != nil { + oauth.providerName = strings.TrimSpace(providerName) + } return &HTTPProvider{ + providerName: strings.TrimSpace(providerName), apiKey: apiKey, apiBase: normalizedBase, defaultModel: strings.TrimSpace(defaultModel), @@ -36,9 +174,32 @@ func NewHTTPProvider(apiKey, apiBase, defaultModel string, supportsResponsesComp authMode: authMode, timeout: timeout, httpClient: &http.Client{Timeout: timeout}, + oauth: oauth, } } +func ConfigureProviderRuntime(providerName string, pc config.ProviderConfig) { + name := strings.TrimSpace(providerName) + if name == "" { + return + } + providerRuntimeRegistry.mu.Lock() + defer providerRuntimeRegistry.mu.Unlock() + state := providerRuntimeRegistry.api[name] + state.Persist = providerRuntimePersistConfig{ + Enabled: pc.RuntimePersist, + File: runtimeHistoryFile(name, pc), + MaxEvents: runtimeHistoryMax(pc), + Loaded: state.Persist.Loaded, + LoadAttempt: state.Persist.LoadAttempt, + } + if state.Persist.Enabled && !state.Persist.LoadAttempt { + state.Persist.LoadAttempt = true + loadPersistedProviderRuntimeLocked(name, &state) + } + providerRuntimeRegistry.api[name] = state +} + func (p *HTTPProvider) Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) { if p.apiBase == "" { return nil, fmt.Errorf("API base not configured") @@ -488,33 +649,216 @@ func (p *HTTPProvider) postJSONStream(ctx context.Context, endpoint string, payl if err != nil { return nil, 0, "", fmt.Errorf("failed to marshal request: %w", err) } - req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(jsonData)) + attempts, err := p.authAttempts(ctx) if err != nil { - return nil, 0, "", fmt.Errorf("failed to create request: %w", err) + return nil, 0, "", err } - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Accept", "text/event-stream") - if p.apiKey != "" { - if p.authMode == "oauth" { - req.Header.Set("Authorization", "Bearer "+p.apiKey) - } else if strings.Contains(p.apiBase, "googleapis.com") { - req.Header.Set("x-goog-api-key", p.apiKey) - } else { - req.Header.Set("Authorization", "Bearer "+p.apiKey) + var lastBody []byte + var lastStatus int + var lastType string + for _, attempt := range attempts { + req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(jsonData)) + if err != nil { + return nil, 0, "", fmt.Errorf("failed to create request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "text/event-stream") + applyAttemptAuth(req, attempt) + + body, status, ctype, quotaHit, err := p.doStreamAttempt(req, onEvent) + if err != nil { + return nil, 0, "", err + } + if !quotaHit { + p.markAttemptSuccess(attempt) + return body, status, ctype, nil + } + lastBody, lastStatus, lastType = body, status, ctype + if attempt.kind == "oauth" && attempt.session != nil && p.oauth != nil { + reason, _ := classifyOAuthFailure(status, body) + p.oauth.markExhausted(attempt.session, reason) + recordProviderOAuthError(p.providerName, attempt.session, reason) + } + if attempt.kind == "api_key" { + reason, _ := classifyOAuthFailure(status, body) + p.markAPIKeyFailure(reason) } } + return lastBody, lastStatus, lastType, nil +} + +func (p *HTTPProvider) postJSON(ctx context.Context, endpoint string, payload interface{}) ([]byte, int, string, error) { + jsonData, err := json.Marshal(payload) + if err != nil { + return nil, 0, "", fmt.Errorf("failed to marshal request: %w", err) + } + attempts, err := p.authAttempts(ctx) + if err != nil { + return nil, 0, "", err + } + var lastBody []byte + var lastStatus int + var lastType string + for _, attempt := range attempts { + req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(jsonData)) + if err != nil { + return nil, 0, "", fmt.Errorf("failed to create request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + applyAttemptAuth(req, attempt) + + body, status, ctype, err := p.doJSONAttempt(req) + if err != nil { + return nil, 0, "", err + } + reason, retry := classifyOAuthFailure(status, body) + if !retry { + p.markAttemptSuccess(attempt) + return body, status, ctype, nil + } + lastBody, lastStatus, lastType = body, status, ctype + if attempt.kind == "oauth" && attempt.session != nil && p.oauth != nil { + p.oauth.markExhausted(attempt.session, reason) + recordProviderOAuthError(p.providerName, attempt.session, reason) + } + if attempt.kind == "api_key" { + p.markAPIKeyFailure(reason) + } + } + return lastBody, lastStatus, lastType, nil +} + +type authAttempt struct { + session *oauthSession + token string + kind string +} + +func (p *HTTPProvider) authAttempts(ctx context.Context) ([]authAttempt, error) { + mode := strings.ToLower(strings.TrimSpace(p.authMode)) + if mode == "oauth" || mode == "hybrid" { + out := make([]authAttempt, 0, 1) + apiAttempt, apiReady := p.apiKeyAttempt() + if p.oauth == nil { + if mode == "hybrid" && apiReady { + return []authAttempt{apiAttempt}, nil + } + return nil, fmt.Errorf("oauth is enabled but provider session manager is not configured") + } + attempts, err := p.oauth.prepareAttemptsLocked(ctx) + if err != nil { + return nil, err + } + oauthAttempts := make([]authAttempt, 0, len(attempts)) + for _, attempt := range attempts { + oauthAttempts = append(oauthAttempts, authAttempt{session: attempt.Session, token: attempt.Token, kind: "oauth"}) + } + if mode == "hybrid" && apiReady && p.oauth.cfg.HybridPriority != "oauth_first" { + out = append(out, apiAttempt) + } + if len(attempts) == 0 { + if len(out) > 0 { + p.updateCandidateOrder(out) + return out, nil + } + return nil, fmt.Errorf("oauth session not found, run `clawgo provider login` first") + } + out = append(out, oauthAttempts...) + if mode == "hybrid" && apiReady && p.oauth.cfg.HybridPriority == "oauth_first" { + out = append(out, apiAttempt) + } + p.updateCandidateOrder(out) + return out, nil + } + apiAttempt, apiReady := p.apiKeyAttempt() + if !apiReady { + return nil, fmt.Errorf("api key temporarily unavailable") + } + out := []authAttempt{apiAttempt} + p.updateCandidateOrder(out) + return out, nil +} + +func (p *HTTPProvider) updateCandidateOrder(attempts []authAttempt) { + name := strings.TrimSpace(p.providerName) + if name == "" { + return + } + candidates := make([]providerRuntimeCandidate, 0, len(attempts)) + for _, attempt := range attempts { + candidate := providerRuntimeCandidate{ + Kind: attempt.kind, + Available: true, + Status: "ready", + } + if attempt.kind == "api_key" { + candidate.Target = maskToken(p.apiKey) + candidate.HealthScore = providerAPIHealth(name) + } else if attempt.session != nil { + candidate.Target = firstNonEmpty(attempt.session.Email, attempt.session.AccountID, attempt.session.FilePath) + candidate.HealthScore = sessionHealthScore(attempt.session) + candidate.FailureCount = attempt.session.FailureCount + candidate.CooldownUntil = attempt.session.CooldownUntil + } + candidates = append(candidates, candidate) + } + providerRuntimeRegistry.mu.Lock() + defer providerRuntimeRegistry.mu.Unlock() + state := providerRuntimeRegistry.api[name] + if !providerCandidatesEqual(state.CandidateOrder, candidates) { + state.RecentChanges = appendRuntimeEvent(state.RecentChanges, providerRuntimeEvent{ + When: time.Now().Format(time.RFC3339), + Kind: "scheduler", + Target: name, + Reason: "candidate_order_changed", + Detail: candidateOrderChangeDetail(state.CandidateOrder, candidates), + }, runtimeEventLimit(state)) + } + state.CandidateOrder = candidates + persistProviderRuntimeLocked(name, state) + providerRuntimeRegistry.api[name] = state +} + +func applyAttemptAuth(req *http.Request, attempt authAttempt) { + if req == nil { + return + } + if strings.TrimSpace(attempt.token) == "" { + return + } + if strings.Contains(req.URL.Host, "googleapis.com") { + req.Header.Set("x-goog-api-key", attempt.token) + return + } + req.Header.Set("Authorization", "Bearer "+attempt.token) +} + +func (p *HTTPProvider) doJSONAttempt(req *http.Request) ([]byte, int, string, error) { resp, err := p.httpClient.Do(req) if err != nil { return nil, 0, "", fmt.Errorf("failed to send request: %w", err) } defer resp.Body.Close() + body, readErr := io.ReadAll(resp.Body) + if readErr != nil { + return nil, resp.StatusCode, strings.TrimSpace(resp.Header.Get("Content-Type")), fmt.Errorf("failed to read response: %w", readErr) + } + return body, resp.StatusCode, strings.TrimSpace(resp.Header.Get("Content-Type")), nil +} + +func (p *HTTPProvider) doStreamAttempt(req *http.Request, onEvent func(string)) ([]byte, int, string, bool, error) { + resp, err := p.httpClient.Do(req) + if err != nil { + return nil, 0, "", false, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() ctype := strings.TrimSpace(resp.Header.Get("Content-Type")) if !strings.Contains(strings.ToLower(ctype), "text/event-stream") { body, readErr := io.ReadAll(resp.Body) if readErr != nil { - return nil, resp.StatusCode, ctype, fmt.Errorf("failed to read response: %w", readErr) + return nil, resp.StatusCode, ctype, false, fmt.Errorf("failed to read response: %w", readErr) } - return body, resp.StatusCode, ctype, nil + return body, resp.StatusCode, ctype, shouldRetryOAuthQuota(resp.StatusCode, body), nil } scanner := bufio.NewScanner(resp.Body) @@ -556,46 +900,903 @@ func (p *HTTPProvider) postJSONStream(ctx context.Context, endpoint string, payl } } if err := scanner.Err(); err != nil { - return nil, resp.StatusCode, ctype, fmt.Errorf("failed to read stream: %w", err) + return nil, resp.StatusCode, ctype, false, fmt.Errorf("failed to read stream: %w", err) } if len(finalJSON) == 0 { finalJSON = []byte("{}") } - return finalJSON, resp.StatusCode, ctype, nil + return finalJSON, resp.StatusCode, ctype, false, nil } -func (p *HTTPProvider) postJSON(ctx context.Context, endpoint string, payload interface{}) ([]byte, int, string, error) { - jsonData, err := json.Marshal(payload) - if err != nil { - return nil, 0, "", fmt.Errorf("failed to marshal request: %w", err) - } +func shouldRetryOAuthQuota(status int, body []byte) bool { + _, retry := classifyOAuthFailure(status, body) + return retry +} - req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(jsonData)) - if err != nil { - return nil, 0, "", fmt.Errorf("failed to create request: %w", err) +func classifyOAuthFailure(status int, body []byte) (oauthFailureReason, bool) { + if status != http.StatusTooManyRequests && status != http.StatusPaymentRequired && status != http.StatusForbidden { + return "", false } - req.Header.Set("Content-Type", "application/json") - if p.apiKey != "" { - if p.authMode == "oauth" { - req.Header.Set("Authorization", "Bearer "+p.apiKey) - } else if strings.Contains(p.apiBase, "googleapis.com") { - req.Header.Set("x-goog-api-key", p.apiKey) - } else { - req.Header.Set("Authorization", "Bearer "+p.apiKey) + lower := strings.ToLower(string(body)) + if strings.Contains(lower, "insufficient_quota") || strings.Contains(lower, "quota") || strings.Contains(lower, "billing") { + return oauthFailureQuota, true + } + if strings.Contains(lower, "rate limit") || strings.Contains(lower, "rate_limit") || strings.Contains(lower, "usage limit") { + return oauthFailureRateLimit, true + } + if status == http.StatusForbidden { + return oauthFailureForbidden, true + } + if status == http.StatusTooManyRequests { + return oauthFailureRateLimit, true + } + return "", false +} + +func (p *HTTPProvider) markAttemptSuccess(attempt authAttempt) { + if attempt.kind == "oauth" && attempt.session != nil && p.oauth != nil { + p.oauth.markSuccess(attempt.session) + } + if attempt.kind == "api_key" { + p.markAPIKeySuccess() + } + p.recordProviderHit(attempt, "") +} + +func (p *HTTPProvider) markAPIKeyFailure(reason oauthFailureReason) { + name := strings.TrimSpace(p.providerName) + if name == "" { + return + } + providerRuntimeRegistry.mu.Lock() + defer providerRuntimeRegistry.mu.Unlock() + state := providerRuntimeRegistry.api[name] + if state.API.HealthScore <= 0 { + state.API.HealthScore = 100 + } + state.API.FailureCount++ + state.API.LastFailure = string(reason) + state.API.HealthScore = maxInt(1, state.API.HealthScore-healthPenaltyForReason(reason)) + cooldown := 15 * time.Minute + switch reason { + case oauthFailureQuota: + cooldown = 60 * time.Minute + case oauthFailureForbidden: + cooldown = 30 * time.Minute + } + state.API.CooldownUntil = time.Now().Add(cooldown).Format(time.RFC3339) + state.API.TokenMasked = maskToken(p.apiKey) + state.RecentErrors = appendRuntimeEvent(state.RecentErrors, providerRuntimeEvent{ + When: time.Now().Format(time.RFC3339), + Kind: "api_key", + Target: maskToken(p.apiKey), + Reason: string(reason), + }, runtimeEventLimit(state)) + state.RecentChanges = appendRuntimeEvent(state.RecentChanges, providerRuntimeEvent{ + When: time.Now().Format(time.RFC3339), + Kind: "api_key", + Target: maskToken(p.apiKey), + Reason: "api_key_cooldown_" + string(reason), + Detail: "api key entered cooldown after request failure", + }, runtimeEventLimit(state)) + persistProviderRuntimeLocked(name, state) + providerRuntimeRegistry.api[name] = state +} + +func (p *HTTPProvider) markAPIKeySuccess() { + name := strings.TrimSpace(p.providerName) + if name == "" { + return + } + providerRuntimeRegistry.mu.Lock() + defer providerRuntimeRegistry.mu.Unlock() + state := providerRuntimeRegistry.api[name] + if state.API.HealthScore <= 0 { + state.API.HealthScore = 100 + } else { + state.API.HealthScore = minInt(100, state.API.HealthScore+3) + } + wasCooling := strings.TrimSpace(state.API.CooldownUntil) != "" + state.API.CooldownUntil = "" + state.API.TokenMasked = maskToken(p.apiKey) + if wasCooling { + state.RecentChanges = appendRuntimeEvent(state.RecentChanges, providerRuntimeEvent{ + When: time.Now().Format(time.RFC3339), + Kind: "api_key", + Target: maskToken(p.apiKey), + Reason: "api_key_recovered", + Detail: "api key cooldown cleared after successful request", + }, runtimeEventLimit(state)) + } + persistProviderRuntimeLocked(name, state) + providerRuntimeRegistry.api[name] = state +} + +func (p *HTTPProvider) apiKeyAttempt() (authAttempt, bool) { + token := strings.TrimSpace(p.apiKey) + if token == "" { + return authAttempt{}, false + } + name := strings.TrimSpace(p.providerName) + if name == "" { + return authAttempt{token: token, kind: "api_key"}, true + } + providerRuntimeRegistry.mu.Lock() + defer providerRuntimeRegistry.mu.Unlock() + state := providerRuntimeRegistry.api[name] + if state.API.TokenMasked == "" { + state.API.TokenMasked = maskToken(token) + } + if state.API.HealthScore <= 0 { + state.API.HealthScore = 100 + } + if state.API.CooldownUntil != "" { + if until, err := time.Parse(time.RFC3339, state.API.CooldownUntil); err == nil { + if time.Now().Before(until) { + providerRuntimeRegistry.api[name] = state + return authAttempt{}, false + } + } + state.API.CooldownUntil = "" + } + providerRuntimeRegistry.api[name] = state + return authAttempt{token: token, kind: "api_key"}, true +} + +func providerAPIHealth(name string) int { + providerRuntimeRegistry.mu.Lock() + defer providerRuntimeRegistry.mu.Unlock() + state := providerRuntimeRegistry.api[name] + if state.API.HealthScore <= 0 { + return 100 + } + return state.API.HealthScore +} + +func maskToken(value string) string { + value = strings.TrimSpace(value) + if value == "" { + return "" + } + if len(value) <= 8 { + return value[:2] + "***" + } + return value[:4] + "***" + value[len(value)-4:] +} + +func appendRuntimeEvent(events []providerRuntimeEvent, event providerRuntimeEvent, limit int) []providerRuntimeEvent { + out := append([]providerRuntimeEvent{event}, events...) + if limit <= 0 { + limit = 8 + } + if len(out) > limit { + out = out[:limit] + } + return out +} + +func (p *HTTPProvider) recordProviderHit(attempt authAttempt, reason string) { + name := strings.TrimSpace(p.providerName) + if name == "" { + return + } + target := "" + if attempt.kind == "api_key" { + target = maskToken(p.apiKey) + } else if attempt.session != nil { + target = firstNonEmpty(attempt.session.Email, attempt.session.AccountID, attempt.session.FilePath) + } + event := providerRuntimeEvent{ + When: time.Now().Format(time.RFC3339), + Kind: attempt.kind, + Target: target, + Reason: reason, + } + providerRuntimeRegistry.mu.Lock() + defer providerRuntimeRegistry.mu.Unlock() + state := providerRuntimeRegistry.api[name] + state.RecentHits = appendRuntimeEvent(state.RecentHits, event, runtimeEventLimit(state)) + state.LastSuccess = &event + persistProviderRuntimeLocked(name, state) + providerRuntimeRegistry.api[name] = state +} + +func recordProviderOAuthError(providerName string, session *oauthSession, reason oauthFailureReason) { + name := strings.TrimSpace(providerName) + if name == "" || session == nil { + return + } + providerRuntimeRegistry.mu.Lock() + defer providerRuntimeRegistry.mu.Unlock() + state := providerRuntimeRegistry.api[name] + state.RecentErrors = appendRuntimeEvent(state.RecentErrors, providerRuntimeEvent{ + When: time.Now().Format(time.RFC3339), + Kind: "oauth", + Target: firstNonEmpty(session.Email, session.AccountID, session.FilePath), + Reason: string(reason), + }, runtimeEventLimit(state)) + persistProviderRuntimeLocked(name, state) + providerRuntimeRegistry.api[name] = state +} + +func ClearProviderAPICooldown(providerName string) { + name := strings.TrimSpace(providerName) + if name == "" { + return + } + providerRuntimeRegistry.mu.Lock() + defer providerRuntimeRegistry.mu.Unlock() + state := providerRuntimeRegistry.api[name] + target := state.API.TokenMasked + state.API.CooldownUntil = "" + state.RecentChanges = appendRuntimeEvent(state.RecentChanges, providerRuntimeEvent{ + When: time.Now().Format(time.RFC3339), + Kind: "api_key", + Target: target, + Reason: "manual_clear_api_cooldown", + Detail: "api key cooldown cleared from runtime panel", + }, runtimeEventLimit(state)) + persistProviderRuntimeLocked(name, state) + providerRuntimeRegistry.api[name] = state +} + +func ClearProviderRuntimeHistory(providerName string) { + name := strings.TrimSpace(providerName) + if name == "" { + return + } + providerRuntimeRegistry.mu.Lock() + defer providerRuntimeRegistry.mu.Unlock() + state := providerRuntimeRegistry.api[name] + state.RecentHits = nil + state.RecentErrors = nil + state.RecentChanges = nil + state.LastSuccess = nil + if state.Persist.Enabled && strings.TrimSpace(state.Persist.File) != "" { + _ = os.Remove(state.Persist.File) + } + providerRuntimeRegistry.api[name] = state +} + +func runtimeEventLimit(state providerRuntimeState) int { + if state.Persist.MaxEvents > 0 { + return state.Persist.MaxEvents + } + return 8 +} + +func runtimeHistoryMax(pc config.ProviderConfig) int { + if pc.RuntimeHistoryMax > 0 { + return pc.RuntimeHistoryMax + } + return 24 +} + +func runtimeHistoryFile(name string, pc config.ProviderConfig) string { + if file := strings.TrimSpace(pc.RuntimeHistoryFile); file != "" { + return file + } + return filepath.Join(config.GetConfigDir(), "runtime", "providers", strings.TrimSpace(name)+".json") +} + +func loadPersistedProviderRuntimeLocked(name string, state *providerRuntimeState) { + if state == nil || !state.Persist.Enabled || strings.TrimSpace(state.Persist.File) == "" { + return + } + raw, err := os.ReadFile(state.Persist.File) + if err != nil { + if os.IsNotExist(err) { + state.Persist.Loaded = true + } + return + } + var persisted providerRuntimeState + if err := json.Unmarshal(raw, &persisted); err != nil { + state.Persist.Loaded = true + return + } + if state.API == (providerAPIRuntimeState{}) { + state.API = persisted.API + } + if len(state.RecentHits) == 0 { + state.RecentHits = persisted.RecentHits + } + if len(state.RecentErrors) == 0 { + state.RecentErrors = persisted.RecentErrors + } + if len(state.RecentChanges) == 0 { + state.RecentChanges = persisted.RecentChanges + } + if state.LastSuccess == nil && persisted.LastSuccess != nil { + last := *persisted.LastSuccess + state.LastSuccess = &last + } + if len(state.CandidateOrder) == 0 { + state.CandidateOrder = persisted.CandidateOrder + } + state.Persist.Loaded = true +} + +func persistProviderRuntimeLocked(name string, state providerRuntimeState) { + if !state.Persist.Enabled || strings.TrimSpace(state.Persist.File) == "" { + return + } + if err := os.MkdirAll(filepath.Dir(state.Persist.File), 0o700); err != nil { + return + } + payload := providerRuntimeState{ + API: state.API, + RecentHits: trimRuntimeEvents(state.RecentHits, runtimeEventLimit(state)), + RecentErrors: trimRuntimeEvents(state.RecentErrors, runtimeEventLimit(state)), + RecentChanges: trimRuntimeEvents(state.RecentChanges, runtimeEventLimit(state)), + LastSuccess: state.LastSuccess, + CandidateOrder: state.CandidateOrder, + } + raw, err := json.MarshalIndent(payload, "", " ") + if err != nil { + return + } + _ = os.WriteFile(state.Persist.File, raw, 0o600) +} + +func trimRuntimeEvents(events []providerRuntimeEvent, limit int) []providerRuntimeEvent { + if limit <= 0 || len(events) <= limit { + return events + } + return events[:limit] +} + +func eventTimeUnix(event providerRuntimeEvent) int64 { + when, err := time.Parse(time.RFC3339, strings.TrimSpace(event.When)) + if err != nil { + return 0 + } + return when.Unix() +} + +func filterRuntimeEvents(events []providerRuntimeEvent, query ProviderRuntimeQuery) []providerRuntimeEvent { + if len(events) == 0 { + return nil + } + kind := strings.TrimSpace(query.EventKind) + reason := strings.TrimSpace(query.Reason) + target := strings.ToLower(strings.TrimSpace(query.Target)) + var cutoff time.Time + if query.Window > 0 { + cutoff = time.Now().Add(-query.Window) + } + filtered := make([]providerRuntimeEvent, 0, len(events)) + for _, event := range events { + if !cutoff.IsZero() { + when, err := time.Parse(time.RFC3339, strings.TrimSpace(event.When)) + if err != nil || when.Before(cutoff) { + continue + } + } + if kind != "" && !strings.EqualFold(strings.TrimSpace(event.Kind), kind) { + continue + } + if reason != "" && !strings.Contains(strings.ToLower(strings.TrimSpace(event.Reason)), strings.ToLower(reason)) { + continue + } + if target != "" && !strings.Contains(strings.ToLower(strings.TrimSpace(event.Target)), target) && !strings.Contains(strings.ToLower(strings.TrimSpace(event.Detail)), target) { + continue + } + filtered = append(filtered, event) + } + return filtered +} + +func mergeRuntimeEvents(item map[string]interface{}, query ProviderRuntimeQuery) ([]providerRuntimeEvent, int) { + hits, _ := item["recent_hits"].([]providerRuntimeEvent) + errors, _ := item["recent_errors"].([]providerRuntimeEvent) + changes, _ := item["recent_changes"].([]providerRuntimeEvent) + merged := make([]providerRuntimeEvent, 0, len(hits)+len(errors)+len(changes)) + if !query.ChangesOnly { + merged = append(merged, filterRuntimeEvents(hits, query)...) + merged = append(merged, filterRuntimeEvents(errors, query)...) + } + merged = append(merged, filterRuntimeEvents(changes, query)...) + desc := !strings.EqualFold(strings.TrimSpace(query.Sort), "asc") + for i := 0; i < len(merged); i++ { + for j := i + 1; j < len(merged); j++ { + left := eventTimeUnix(merged[i]) + right := eventTimeUnix(merged[j]) + swap := right > left + if !desc { + swap = right < left + } + if swap { + merged[i], merged[j] = merged[j], merged[i] + } } } + start := query.Cursor + if start < 0 { + start = 0 + } + if start > len(merged) { + start = len(merged) + } + limit := query.Limit + if limit <= 0 { + limit = 20 + } + end := start + limit + if end > len(merged) { + end = len(merged) + } + nextCursor := 0 + if end < len(merged) { + nextCursor = end + } + return merged[start:end], nextCursor +} - resp, err := p.httpClient.Do(req) +func matchesProviderCandidateFilters(item map[string]interface{}, query ProviderRuntimeQuery) bool { + if query.HealthBelow <= 0 && query.CooldownBefore.IsZero() { + return true + } + apiState, _ := item["api_state"].(providerAPIRuntimeState) + candidates, _ := item["candidate_order"].([]providerRuntimeCandidate) + if query.HealthBelow > 0 { + if runtimeHealthValue(apiState.HealthScore) < query.HealthBelow { + return true + } + for _, candidate := range candidates { + if runtimeHealthValue(candidate.HealthScore) < query.HealthBelow { + return true + } + } + } + if !query.CooldownBefore.IsZero() { + values := []string{apiState.CooldownUntil} + for _, candidate := range candidates { + values = append(values, candidate.CooldownUntil) + } + for _, value := range values { + if strings.TrimSpace(value) == "" { + continue + } + until, err := time.Parse(time.RFC3339, strings.TrimSpace(value)) + if err == nil && until.Before(query.CooldownBefore) { + return true + } + } + } + return false +} + +func providerInCooldown(item map[string]interface{}) bool { + apiState, _ := item["api_state"].(providerAPIRuntimeState) + if cooldownActive(apiState.CooldownUntil) { + return true + } + candidates, _ := item["candidate_order"].([]providerRuntimeCandidate) + for _, candidate := range candidates { + if cooldownActive(candidate.CooldownUntil) { + return true + } + } + return false +} + +func cooldownActive(value string) bool { + if strings.TrimSpace(value) == "" { + return false + } + until, err := time.Parse(time.RFC3339, strings.TrimSpace(value)) + return err == nil && time.Now().Before(until) +} + +func buildProviderCandidateOrder(_ string, pc config.ProviderConfig, accounts []OAuthAccountInfo, api providerAPIRuntimeState) []providerRuntimeCandidate { + authMode := strings.ToLower(strings.TrimSpace(pc.Auth)) + apiCandidate := providerRuntimeCandidate{ + Kind: "api_key", + Target: maskToken(pc.APIKey), + Available: strings.TrimSpace(pc.APIKey) != "", + Status: "ready", + CooldownUntil: strings.TrimSpace(api.CooldownUntil), + HealthScore: runtimeHealthValue(api.HealthScore), + FailureCount: api.FailureCount, + } + if strings.TrimSpace(apiCandidate.CooldownUntil) != "" { + if until, err := time.Parse(time.RFC3339, apiCandidate.CooldownUntil); err == nil && time.Now().Before(until) { + apiCandidate.Available = false + apiCandidate.Status = "cooldown" + } + } + oauthAvailable := make([]providerRuntimeCandidate, 0, len(accounts)) + oauthUnavailable := make([]providerRuntimeCandidate, 0, len(accounts)) + for _, account := range accounts { + candidate := providerRuntimeCandidate{ + Kind: "oauth", + Target: firstNonEmpty(account.Email, account.AccountID, account.CredentialFile), + Available: true, + Status: "ready", + CooldownUntil: strings.TrimSpace(account.CooldownUntil), + HealthScore: runtimeHealthValue(account.HealthScore), + FailureCount: account.FailureCount, + } + if strings.TrimSpace(candidate.CooldownUntil) != "" { + if until, err := time.Parse(time.RFC3339, candidate.CooldownUntil); err == nil && time.Now().Before(until) { + candidate.Available = false + candidate.Status = "cooldown" + } + } + if candidate.Available { + oauthAvailable = append(oauthAvailable, candidate) + } else { + oauthUnavailable = append(oauthUnavailable, candidate) + } + } + sortRuntimeCandidates(oauthAvailable) + sortRuntimeCandidates(oauthUnavailable) + out := make([]providerRuntimeCandidate, 0, 1+len(accounts)) + switch authMode { + case "oauth": + out = append(out, oauthAvailable...) + case "hybrid": + if strings.EqualFold(strings.TrimSpace(pc.OAuth.HybridPriority), "oauth_first") { + out = append(out, oauthAvailable...) + if apiCandidate.Target != "" && apiCandidate.Available { + out = append(out, apiCandidate) + } + } else { + if apiCandidate.Target != "" && apiCandidate.Available { + out = append(out, apiCandidate) + } + out = append(out, oauthAvailable...) + } + case "none": + default: + if apiCandidate.Target != "" { + out = append(out, apiCandidate) + } + } + if authMode == "hybrid" { + if apiCandidate.Target != "" && !apiCandidate.Available { + out = append(out, apiCandidate) + } + out = append(out, oauthUnavailable...) + } else if authMode == "oauth" { + out = append(out, oauthUnavailable...) + } + return out +} + +func runtimeHealthValue(value int) int { + if value <= 0 { + return 100 + } + return value +} + +func sortRuntimeCandidates(items []providerRuntimeCandidate) { + for i := 0; i < len(items); i++ { + for j := i + 1; j < len(items); j++ { + if items[j].HealthScore > items[i].HealthScore || (items[j].HealthScore == items[i].HealthScore && items[j].Target < items[i].Target) { + items[i], items[j] = items[j], items[i] + } + } + } +} + +func providerCandidatesEqual(left, right []providerRuntimeCandidate) bool { + if len(left) != len(right) { + return false + } + for i := range left { + if left[i].Kind != right[i].Kind || + left[i].Target != right[i].Target || + left[i].Available != right[i].Available || + left[i].Status != right[i].Status || + left[i].CooldownUntil != right[i].CooldownUntil || + left[i].HealthScore != right[i].HealthScore || + left[i].FailureCount != right[i].FailureCount { + return false + } + } + return true +} + +func summarizeCandidate(candidate providerRuntimeCandidate) string { + target := strings.TrimSpace(candidate.Target) + if target == "" { + target = "-" + } + return strings.TrimSpace(candidate.Kind) + ":" + target +} + +func candidateOrderChangeDetail(before, after []providerRuntimeCandidate) string { + if len(before) == 0 && len(after) == 0 { + return "" + } + beforeTop := "-" + afterTop := "-" + if len(before) > 0 { + beforeTop = summarizeCandidate(before[0]) + } + if len(after) > 0 { + afterTop = summarizeCandidate(after[0]) + } + beforeOrder := make([]string, 0, len(before)) + for _, item := range before { + beforeOrder = append(beforeOrder, summarizeCandidate(item)) + } + afterOrder := make([]string, 0, len(after)) + for _, item := range after { + afterOrder = append(afterOrder, summarizeCandidate(item)) + } + return fmt.Sprintf("top %s -> %s | order [%s] -> [%s]", beforeTop, afterTop, strings.Join(beforeOrder, " > "), strings.Join(afterOrder, " > ")) +} + +func GetProviderRuntimeSnapshot(cfg *config.Config) map[string]interface{} { + if cfg == nil { + return map[string]interface{}{"items": []interface{}{}} + } + items := make([]map[string]interface{}, 0) + configs := getAllProviderConfigs(cfg) + for name, pc := range configs { + ConfigureProviderRuntime(name, pc) + providerRuntimeRegistry.mu.Lock() + state := providerRuntimeRegistry.api[name] + providerRuntimeRegistry.mu.Unlock() + item := map[string]interface{}{ + "name": name, + "auth": strings.TrimSpace(pc.Auth), + "api_base": strings.TrimSpace(pc.APIBase), + "api_state": state.API, + "recent_hits": state.RecentHits, + "recent_errors": state.RecentErrors, + "recent_changes": state.RecentChanges, + "last_success": state.LastSuccess, + } + candidateOrder := state.CandidateOrder + if strings.EqualFold(strings.TrimSpace(pc.Auth), "oauth") || strings.EqualFold(strings.TrimSpace(pc.Auth), "hybrid") { + if mgr, err := NewOAuthLoginManager(pc, time.Duration(maxInt(pc.TimeoutSec, 90))*time.Second); err == nil { + if accounts, err := mgr.ListAccounts(); err == nil { + item["oauth_accounts"] = accounts + candidateOrder = buildProviderCandidateOrder(name, pc, accounts, state.API) + } + } + } else if len(candidateOrder) == 0 && strings.TrimSpace(pc.APIKey) != "" { + candidateOrder = buildProviderCandidateOrder(name, pc, nil, state.API) + } + if len(candidateOrder) > 0 { + providerRuntimeRegistry.mu.Lock() + state = providerRuntimeRegistry.api[name] + state.CandidateOrder = candidateOrder + persistProviderRuntimeLocked(name, state) + providerRuntimeRegistry.api[name] = state + providerRuntimeRegistry.mu.Unlock() + } + item["candidate_order"] = candidateOrder + items = append(items, item) + } + return map[string]interface{}{"items": items} +} + +func GetProviderRuntimeView(cfg *config.Config, query ProviderRuntimeQuery) map[string]interface{} { + if cfg == nil { + return map[string]interface{}{"items": []interface{}{}} + } + snapshot := GetProviderRuntimeSnapshot(cfg) + rawItems, _ := snapshot["items"].([]map[string]interface{}) + if len(rawItems) == 0 { + return map[string]interface{}{"items": []interface{}{}} + } + filterName := strings.TrimSpace(query.Provider) + items := make([]map[string]interface{}, 0, len(rawItems)) + for _, item := range rawItems { + name := strings.TrimSpace(fmt.Sprintf("%v", item["name"])) + if filterName != "" && name != filterName { + continue + } + next := map[string]interface{}{} + for key, value := range item { + next[key] = value + } + hits, _ := item["recent_hits"].([]providerRuntimeEvent) + errors, _ := item["recent_errors"].([]providerRuntimeEvent) + changes, _ := item["recent_changes"].([]providerRuntimeEvent) + next["recent_hits"] = filterRuntimeEvents(hits, query) + next["recent_errors"] = filterRuntimeEvents(errors, query) + next["recent_changes"] = filterRuntimeEvents(changes, query) + if query.ChangesOnly { + next["recent_hits"] = []providerRuntimeEvent{} + next["recent_errors"] = []providerRuntimeEvent{} + } + events, nextCursor := mergeRuntimeEvents(next, query) + next["events"] = events + next["next_cursor"] = nextCursor + if !matchesProviderCandidateFilters(next, query) { + continue + } + items = append(items, next) + } + return map[string]interface{}{"items": items} +} + +func GetProviderRuntimeSummary(cfg *config.Config, query ProviderRuntimeQuery) ProviderRuntimeSummary { + snapshot := GetProviderRuntimeSnapshot(cfg) + rawItems, _ := snapshot["items"].([]map[string]interface{}) + summary := ProviderRuntimeSummary{Providers: make([]ProviderRuntimeSummaryItem, 0, len(rawItems))} + for _, item := range rawItems { + name := strings.TrimSpace(fmt.Sprintf("%v", item["name"])) + if strings.TrimSpace(query.Provider) != "" && name != strings.TrimSpace(query.Provider) { + continue + } + auth := strings.TrimSpace(fmt.Sprintf("%v", item["auth"])) + apiState, _ := item["api_state"].(providerAPIRuntimeState) + accounts, _ := item["oauth_accounts"].([]OAuthAccountInfo) + candidates, _ := item["candidate_order"].([]providerRuntimeCandidate) + errors, _ := item["recent_errors"].([]providerRuntimeEvent) + changes, _ := item["recent_changes"].([]providerRuntimeEvent) + errors = filterRuntimeEvents(errors, query) + changes = filterRuntimeEvents(changes, query) + lastSuccess, _ := item["last_success"].(*providerRuntimeEvent) + inCooldown := providerInCooldown(item) + lowHealth := matchesProviderCandidateFilters(item, ProviderRuntimeQuery{HealthBelow: maxInt(query.HealthBelow, 1)}) + hasRecentErrors := len(errors) > 0 + lastError := latestProviderRuntimeEvent(errors) + topChangedAt := latestRuntimeChangeAt(changes, "candidate_order_changed") + status := providerRuntimeSummaryStatus(inCooldown, lowHealth, hasRecentErrors) + providerItem := ProviderRuntimeSummaryItem{ + Name: name, + Auth: auth, + Status: status, + APIState: apiState, + OAuthAccounts: accounts, + CandidateOrder: candidates, + LastSuccess: lastSuccess, + LastError: lastError, + TopCandidateChangedAt: topChangedAt, + InCooldown: inCooldown, + LowHealth: lowHealth, + HasRecentErrors: hasRecentErrors, + } + if lastSuccess != nil { + providerItem.LastSuccessAt = strings.TrimSpace(lastSuccess.When) + if when := parseRuntimeEventTime(*lastSuccess); !when.IsZero() { + providerItem.StaleForSec = int64(time.Since(when).Seconds()) + } + } else { + providerItem.StaleForSec = -1 + } + if lastError != nil { + providerItem.LastErrorAt = strings.TrimSpace(lastError.When) + providerItem.LastErrorReason = strings.TrimSpace(lastError.Reason) + } + if len(candidates) > 0 { + top := candidates[0] + providerItem.TopCandidate = &top + } + summary.TotalProviders++ + switch status { + case "critical": + summary.Critical++ + case "degraded": + summary.Degraded++ + default: + summary.Healthy++ + } + if inCooldown { + summary.InCooldown++ + } + if lowHealth { + summary.LowHealth++ + } + if hasRecentErrors { + summary.RecentErrors++ + } + if inCooldown || lowHealth || hasRecentErrors || strings.TrimSpace(query.Provider) != "" { + summary.Providers = append(summary.Providers, providerItem) + } + } + return summary +} + +func latestProviderRuntimeEvent(events []providerRuntimeEvent) *providerRuntimeEvent { + if len(events) == 0 { + return nil + } + best := events[0] + bestTime := eventTimeUnix(best) + for i := 1; i < len(events); i++ { + currentTime := eventTimeUnix(events[i]) + if currentTime > bestTime { + best = events[i] + bestTime = currentTime + } + } + copyEvent := best + return ©Event +} + +func latestRuntimeChangeAt(events []providerRuntimeEvent, reason string) string { + targetReason := strings.TrimSpace(reason) + if targetReason == "" || len(events) == 0 { + return "" + } + var latest *providerRuntimeEvent + var latestUnix int64 + for i := range events { + if !strings.EqualFold(strings.TrimSpace(events[i].Reason), targetReason) { + continue + } + currentUnix := eventTimeUnix(events[i]) + if latest == nil || currentUnix > latestUnix { + eventCopy := events[i] + latest = &eventCopy + latestUnix = currentUnix + } + } + if latest == nil { + return "" + } + return strings.TrimSpace(latest.When) +} + +func parseRuntimeEventTime(event providerRuntimeEvent) time.Time { + when, err := time.Parse(time.RFC3339, strings.TrimSpace(event.When)) if err != nil { - return nil, 0, "", fmt.Errorf("failed to send request: %w", err) + return time.Time{} } - defer resp.Body.Close() + return when +} - body, readErr := io.ReadAll(resp.Body) - if readErr != nil { - return nil, resp.StatusCode, strings.TrimSpace(resp.Header.Get("Content-Type")), fmt.Errorf("failed to read response: %w", readErr) +func providerRuntimeSummaryStatus(inCooldown, lowHealth, hasRecentErrors bool) string { + if inCooldown || lowHealth { + return "critical" } - return body, resp.StatusCode, strings.TrimSpace(resp.Header.Get("Content-Type")), nil + if hasRecentErrors { + return "degraded" + } + return "healthy" +} + +func RefreshProviderRuntimeNow(cfg *config.Config, providerName string, onlyExpiring bool) (*ProviderRefreshResult, error) { + pc, err := getProviderConfigByName(cfg, providerName) + if err != nil { + return nil, err + } + if !strings.EqualFold(strings.TrimSpace(pc.Auth), "oauth") && !strings.EqualFold(strings.TrimSpace(pc.Auth), "hybrid") { + return nil, fmt.Errorf("provider %q does not use oauth", providerName) + } + manager, err := newOAuthManager(pc, time.Duration(maxInt(pc.TimeoutSec, 90))*time.Second) + if err != nil { + return nil, err + } + defer manager.bgCancel() + manager.providerName = strings.TrimSpace(providerName) + lead := 365 * 24 * time.Hour + if onlyExpiring { + lead = manager.cfg.RefreshLead + if lead <= 0 { + lead = 30 * time.Minute + } + } + return manager.refreshExpiringSessions(context.Background(), lead) +} + +func RerankProviderRuntime(cfg *config.Config, providerName string) ([]providerRuntimeCandidate, error) { + provider, err := CreateProviderByName(cfg, providerName) + if err != nil { + return nil, err + } + httpProvider, ok := provider.(*HTTPProvider) + if !ok { + return nil, fmt.Errorf("provider %q does not support runtime rerank", providerName) + } + _, err = httpProvider.authAttempts(context.Background()) + if err != nil && !strings.Contains(strings.ToLower(err.Error()), "oauth session not found") { + return nil, err + } + providerRuntimeRegistry.mu.Lock() + order := append([]providerRuntimeCandidate(nil), providerRuntimeRegistry.api[strings.TrimSpace(providerName)].CandidateOrder...) + providerRuntimeRegistry.mu.Unlock() + return order, nil } func parseResponsesAPIResponse(body []byte) (*LLMResponse, error) { @@ -942,6 +2143,7 @@ func CreateProviderByName(cfg *config.Config, name string) (LLMProvider, error) if err != nil { return nil, err } + ConfigureProviderRuntime(name, pc) if pc.APIBase == "" { return nil, fmt.Errorf("no API base configured for provider %q", name) } @@ -952,7 +2154,14 @@ func CreateProviderByName(cfg *config.Config, name string) (LLMProvider, error) if len(pc.Models) > 0 { defaultModel = pc.Models[0] } - return NewHTTPProvider(pc.APIKey, pc.APIBase, defaultModel, pc.SupportsResponsesCompact, pc.Auth, time.Duration(pc.TimeoutSec)*time.Second), nil + var oauth *oauthManager + if strings.EqualFold(strings.TrimSpace(pc.Auth), "oauth") || strings.EqualFold(strings.TrimSpace(pc.Auth), "hybrid") { + oauth, err = newOAuthManager(pc, time.Duration(pc.TimeoutSec)*time.Second) + if err != nil { + return nil, err + } + } + return NewHTTPProvider(name, pc.APIKey, pc.APIBase, defaultModel, pc.SupportsResponsesCompact, pc.Auth, time.Duration(pc.TimeoutSec)*time.Second, oauth), nil } func CreateProviders(cfg *config.Config) (map[string]LLMProvider, error) { diff --git a/pkg/providers/oauth.go b/pkg/providers/oauth.go new file mode 100644 index 0000000..9fd3e8c --- /dev/null +++ b/pkg/providers/oauth.go @@ -0,0 +1,2334 @@ +package providers + +import ( + "bufio" + "context" + "crypto/rand" + "crypto/sha256" + "encoding/base64" + "encoding/json" + "fmt" + "io" + "net" + "net/http" + "net/url" + "os" + "os/exec" + "path/filepath" + "runtime" + "strconv" + "strings" + "sync" + "time" + + "github.com/YspCoder/clawgo/pkg/config" +) + +const ( + oauthFlowCallback = "callback" + oauthFlowDevice = "device" + + oauthStyleForm = "form" + oauthStyleJSON = "json" + + defaultCodexOAuthProvider = "codex" + defaultCodexAuthURL = "https://auth.openai.com/oauth/authorize" + defaultCodexTokenURL = "https://auth.openai.com/oauth/token" + defaultCodexClientID = "app_EMoamEEZ73f0CkXaXp7hrann" + defaultCodexCallbackPort = 1455 + defaultCodexRedirectPath = "/auth/callback" + defaultClaudeOAuthProvider = "claude" + defaultClaudeAuthURL = "https://claude.ai/oauth/authorize" + defaultClaudeTokenURL = "https://api.anthropic.com/v1/oauth/token" + defaultClaudeClientID = "9d1c250a-e61b-44d9-88ed-5944d1962f5e" + defaultClaudeCallbackPort = 54545 + defaultClaudeRedirectPath = "/callback" + defaultAntigravityOAuthProvider = "antigravity" + defaultAntigravityAuthURL = "https://accounts.google.com/o/oauth2/v2/auth" + defaultAntigravityTokenURL = "https://oauth2.googleapis.com/token" + defaultAntigravityCallbackPort = 51121 + defaultAntigravityRedirectPath = "/oauth-callback" + defaultGeminiOAuthProvider = "gemini" + defaultGeminiAuthURL = "https://accounts.google.com/o/oauth2/v2/auth" + defaultGeminiTokenURL = "https://oauth2.googleapis.com/token" + defaultGeminiCallbackPort = 8085 + defaultGeminiRedirectPath = "/oauth2callback" + defaultKimiOAuthProvider = "kimi" + defaultKimiDeviceCodeURL = "https://auth.kimi.com/api/oauth/device_authorization" + defaultKimiTokenURL = "https://auth.kimi.com/api/oauth/token" + defaultKimiClientID = "17e5f671-d194-4dfb-9706-5516cb48c098" + defaultQwenOAuthProvider = "qwen" + defaultQwenDeviceCodeURL = "https://chat.qwen.ai/api/v1/oauth2/device/code" + defaultQwenTokenURL = "https://chat.qwen.ai/api/v1/oauth2/token" + defaultQwenClientID = "f0304373b74a44d2b584a3fb70ca9e56" +) + +var ( + defaultCodexScopes = []string{"openid", "email", "profile", "offline_access"} + defaultClaudeScopes = []string{"org:create_api_key", "user:profile", "user:inference"} + defaultGoogleScopes = []string{ + "https://www.googleapis.com/auth/cloud-platform", + "https://www.googleapis.com/auth/userinfo.email", + "https://www.googleapis.com/auth/userinfo.profile", + } + defaultAntigravityScopes = []string{ + "https://www.googleapis.com/auth/cloud-platform", + "https://www.googleapis.com/auth/userinfo.email", + "https://www.googleapis.com/auth/userinfo.profile", + "https://www.googleapis.com/auth/cclog", + "https://www.googleapis.com/auth/experimentsandconfigs", + } + defaultQwenScopes = []string{"openid", "profile", "email", "model.completion"} +) + +var ( + defaultAntigravityClientID = strings.TrimSpace(os.Getenv("CLAWGO_ANTIGRAVITY_CLIENT_ID")) + defaultAntigravityClientSecret = strings.TrimSpace(os.Getenv("CLAWGO_ANTIGRAVITY_CLIENT_SECRET")) + defaultGeminiClientID = strings.TrimSpace(os.Getenv("CLAWGO_GEMINI_CLIENT_ID")) + defaultGeminiClientSecret = strings.TrimSpace(os.Getenv("CLAWGO_GEMINI_CLIENT_SECRET")) +) + +var ( + defaultCodexRefreshLead = 5 * 24 * time.Hour + defaultClaudeRefreshLead = 4 * time.Hour + defaultAntigravityRefreshLead = 5 * time.Minute + defaultKimiRefreshLead = 5 * time.Minute + defaultQwenRefreshLead = 3 * time.Hour + defaultAntigravityUserInfoURL = "https://www.googleapis.com/oauth2/v1/userinfo?alt=json" + defaultGeminiUserInfoURL = "https://www.googleapis.com/oauth2/v1/userinfo?alt=json" + defaultAntigravityAPIEndpoint = "https://cloudcode-pa.googleapis.com" + defaultAntigravityAPIVersion = "v1internal" + defaultAntigravityAPIUserAgent = "google-api-nodejs-client/9.15.1" + defaultAntigravityAPIClient = "google-cloud-sdk vscode_cloudshelleditor/0.1" + defaultAntigravityClientMeta = `{"ideType":"IDE_UNSPECIFIED","platform":"PLATFORM_UNSPECIFIED","pluginType":"GEMINI"}` +) + +type oauthSession struct { + Provider string `json:"provider"` + Type string `json:"type,omitempty"` + AccessToken string `json:"access_token"` + RefreshToken string `json:"refresh_token"` + IDToken string `json:"id_token,omitempty"` + TokenType string `json:"token_type,omitempty"` + AccountID string `json:"account_id,omitempty"` + Email string `json:"email,omitempty"` + Expire string `json:"expire,omitempty"` + LastRefresh string `json:"last_refresh,omitempty"` + Models []string `json:"models,omitempty"` + ProjectID string `json:"project_id,omitempty"` + DeviceID string `json:"device_id,omitempty"` + ResourceURL string `json:"resource_url,omitempty"` + Scope string `json:"scope,omitempty"` + Token map[string]any `json:"token,omitempty"` + CooldownUntil string `json:"-"` + FailureCount int `json:"-"` + LastFailure string `json:"-"` + HealthScore int `json:"-"` + FilePath string `json:"-"` +} + +type oauthConfig struct { + Provider string + CredentialFile string + CredentialFiles []string + CallbackPort int + ClientID string + ClientSecret string + AuthURL string + TokenURL string + DeviceCodeURL string + UserInfoURL string + RedirectURL string + RedirectPath string + Scopes []string + RefreshScan time.Duration + RefreshLead time.Duration + HybridPriority string + Cooldown time.Duration + FlowKind string + TokenStyle string + DeviceGrantType string + DevicePollMax time.Duration +} + +type oauthManager struct { + providerName string + cfg oauthConfig + httpClient *http.Client + mu sync.Mutex + cached []*oauthSession + cooldowns map[string]time.Time + bgCtx context.Context + bgCancel context.CancelFunc +} + +type oauthTokenResponse struct { + AccessToken string `json:"access_token"` + RefreshToken string `json:"refresh_token"` + IDToken string `json:"id_token"` + TokenType string `json:"token_type"` + ExpiresIn int `json:"expires_in"` + ResourceURL string `json:"resource_url"` + Scope string `json:"scope"` + Account struct { + UUID string `json:"uuid"` + EmailAddress string `json:"email_address"` + } `json:"account"` + Organization struct { + UUID string `json:"uuid"` + Name string `json:"name"` + } `json:"organization"` +} + +type oauthDeviceCodeResponse struct { + DeviceCode string `json:"device_code"` + UserCode string `json:"user_code"` + VerificationURI string `json:"verification_uri"` + VerificationURIComplete string `json:"verification_uri_complete"` + ExpiresIn int `json:"expires_in"` + Interval int `json:"interval"` +} + +type OAuthLoginManager struct { + manager *oauthManager +} + +type OAuthLoginOptions struct { + Manual bool + NoBrowser bool + Reader io.Reader + AccountLabel string +} + +type OAuthPendingFlow struct { + Mode string `json:"mode,omitempty"` + State string `json:"state,omitempty"` + PKCEVerifier string `json:"pkce_verifier,omitempty"` + AuthURL string `json:"auth_url,omitempty"` + UserCode string `json:"user_code,omitempty"` + Instructions string `json:"instructions,omitempty"` + DeviceCode string `json:"device_code,omitempty"` + IntervalSec int `json:"interval_sec,omitempty"` + ExpiresAt string `json:"expires_at,omitempty"` +} + +type OAuthSessionInfo struct { + Email string + AccountID string + CredentialFile string + ProjectID string + AccountLabel string +} + +type OAuthAccountInfo struct { + Email string `json:"email"` + AccountID string `json:"account_id"` + CredentialFile string `json:"credential_file"` + Expire string `json:"expire,omitempty"` + LastRefresh string `json:"last_refresh,omitempty"` + ProjectID string `json:"project_id,omitempty"` + AccountLabel string `json:"account_label,omitempty"` + DeviceID string `json:"device_id,omitempty"` + ResourceURL string `json:"resource_url,omitempty"` + CooldownUntil string `json:"cooldown_until,omitempty"` + FailureCount int `json:"failure_count,omitempty"` + LastFailure string `json:"last_failure,omitempty"` + HealthScore int `json:"health_score,omitempty"` +} + +type oauthAttempt struct { + Session *oauthSession + Token string +} + +type oauthFailureReason string + +const ( + oauthFailureQuota oauthFailureReason = "quota" + oauthFailureRateLimit oauthFailureReason = "rate_limit" + oauthFailureForbidden oauthFailureReason = "forbidden" +) + +type oauthCallbackResult struct { + Code string + State string + Err string +} + +func NewOAuthLoginManager(pc config.ProviderConfig, timeout time.Duration) (*OAuthLoginManager, error) { + manager, err := newOAuthManager(pc, timeout) + if err != nil { + return nil, err + } + return &OAuthLoginManager{manager: manager}, nil +} + +func (m *OAuthLoginManager) Login(ctx context.Context, apiBase string, opts OAuthLoginOptions) (*OAuthSessionInfo, []string, error) { + if m == nil || m.manager == nil { + return nil, nil, fmt.Errorf("oauth login manager not configured") + } + session, models, err := m.manager.login(ctx, apiBase, opts) + if err != nil { + return nil, nil, err + } + return &OAuthSessionInfo{ + Email: session.Email, + AccountID: session.AccountID, + CredentialFile: session.FilePath, + ProjectID: session.ProjectID, + AccountLabel: sessionLabel(session), + }, models, nil +} + +func (m *OAuthLoginManager) CredentialFile() string { + if m == nil || m.manager == nil { + return "" + } + return m.manager.cfg.CredentialFile +} + +func (m *OAuthLoginManager) StartManualFlow() (*OAuthPendingFlow, error) { + if m == nil || m.manager == nil { + return nil, fmt.Errorf("oauth login manager not configured") + } + if m.manager.cfg.FlowKind == oauthFlowDevice { + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + return m.manager.startDeviceFlow(ctx) + } + pkceVerifier, pkceChallenge, err := generatePKCE() + if err != nil { + return nil, err + } + state, err := randomURLToken(24) + if err != nil { + return nil, err + } + return &OAuthPendingFlow{ + Mode: oauthFlowCallback, + State: state, + PKCEVerifier: pkceVerifier, + AuthURL: m.manager.authorizationURL(state, pkceChallenge), + Instructions: "Open the authorization URL, finish login, then paste the final callback URL.", + }, nil +} + +func (m *OAuthLoginManager) CompleteManualFlow(ctx context.Context, apiBase string, flow *OAuthPendingFlow, callbackURL string) (*OAuthSessionInfo, []string, error) { + return m.CompleteManualFlowWithOptions(ctx, apiBase, flow, callbackURL, OAuthLoginOptions{}) +} + +func (m *OAuthLoginManager) CompleteManualFlowWithOptions(ctx context.Context, apiBase string, flow *OAuthPendingFlow, callbackURL string, opts OAuthLoginOptions) (*OAuthSessionInfo, []string, error) { + if m == nil || m.manager == nil { + return nil, nil, fmt.Errorf("oauth login manager not configured") + } + if flow == nil { + return nil, nil, fmt.Errorf("oauth flow is nil") + } + var ( + session *oauthSession + models []string + err error + ) + if flow.Mode == oauthFlowDevice { + session, models, err = m.manager.completeDeviceFlow(ctx, apiBase, flow, opts) + } else { + callback, parseErr := parseOAuthCallbackURL(callbackURL) + if parseErr != nil { + return nil, nil, parseErr + } + if callback.State != flow.State { + return nil, nil, fmt.Errorf("oauth callback state mismatch") + } + session, models, err = m.manager.completeLogin(ctx, apiBase, flow.PKCEVerifier, callback, flow.State, opts) + } + if err != nil { + return nil, nil, err + } + return &OAuthSessionInfo{ + Email: session.Email, + AccountID: session.AccountID, + CredentialFile: session.FilePath, + ProjectID: session.ProjectID, + AccountLabel: sessionLabel(session), + }, models, nil +} + +func (m *OAuthLoginManager) ImportAuthJSON(ctx context.Context, apiBase string, fileName string, data []byte) (*OAuthSessionInfo, []string, error) { + return m.ImportAuthJSONWithOptions(ctx, apiBase, fileName, data, OAuthLoginOptions{}) +} + +func (m *OAuthLoginManager) ImportAuthJSONWithOptions(ctx context.Context, apiBase string, fileName string, data []byte, opts OAuthLoginOptions) (*OAuthSessionInfo, []string, error) { + if m == nil || m.manager == nil { + return nil, nil, fmt.Errorf("oauth login manager not configured") + } + session, models, err := m.manager.importSession(ctx, apiBase, fileName, data, opts) + if err != nil { + return nil, nil, err + } + return &OAuthSessionInfo{ + Email: session.Email, + AccountID: session.AccountID, + CredentialFile: session.FilePath, + ProjectID: session.ProjectID, + AccountLabel: sessionLabel(session), + }, models, nil +} + +func (m *OAuthLoginManager) ListAccounts() ([]OAuthAccountInfo, error) { + if m == nil || m.manager == nil { + return nil, fmt.Errorf("oauth login manager not configured") + } + m.manager.mu.Lock() + defer m.manager.mu.Unlock() + sessions, err := m.manager.loadAllLocked() + if err != nil { + return nil, err + } + out := make([]OAuthAccountInfo, 0, len(sessions)) + for _, session := range sessions { + if session == nil { + continue + } + out = append(out, OAuthAccountInfo{ + Email: session.Email, + AccountID: session.AccountID, + CredentialFile: session.FilePath, + Expire: session.Expire, + LastRefresh: session.LastRefresh, + ProjectID: session.ProjectID, + AccountLabel: sessionLabel(session), + DeviceID: session.DeviceID, + ResourceURL: session.ResourceURL, + CooldownUntil: session.CooldownUntil, + FailureCount: session.FailureCount, + LastFailure: session.LastFailure, + HealthScore: sessionHealthScore(session), + }) + } + return out, nil +} + +func (m *OAuthLoginManager) RefreshAccount(ctx context.Context, credentialFile string) (*OAuthAccountInfo, error) { + if m == nil || m.manager == nil { + return nil, fmt.Errorf("oauth login manager not configured") + } + m.manager.mu.Lock() + defer m.manager.mu.Unlock() + sessions, err := m.manager.loadAllLocked() + if err != nil { + return nil, err + } + for _, session := range sessions { + if session == nil || strings.TrimSpace(session.FilePath) != strings.TrimSpace(credentialFile) { + continue + } + refreshed, err := m.manager.refreshSessionLocked(ctx, session) + if err != nil { + return nil, err + } + return &OAuthAccountInfo{ + Email: refreshed.Email, + AccountID: refreshed.AccountID, + CredentialFile: refreshed.FilePath, + Expire: refreshed.Expire, + LastRefresh: refreshed.LastRefresh, + ProjectID: refreshed.ProjectID, + AccountLabel: sessionLabel(refreshed), + DeviceID: refreshed.DeviceID, + ResourceURL: refreshed.ResourceURL, + CooldownUntil: refreshed.CooldownUntil, + FailureCount: refreshed.FailureCount, + LastFailure: refreshed.LastFailure, + HealthScore: sessionHealthScore(refreshed), + }, nil + } + return nil, fmt.Errorf("oauth credential not found") +} + +func (m *OAuthLoginManager) DeleteAccount(credentialFile string) error { + if m == nil || m.manager == nil { + return fmt.Errorf("oauth login manager not configured") + } + path := strings.TrimSpace(credentialFile) + if path == "" { + return fmt.Errorf("oauth credential file is empty") + } + m.manager.mu.Lock() + defer m.manager.mu.Unlock() + if err := os.Remove(path); err != nil && !os.IsNotExist(err) { + return err + } + filtered := make([]*oauthSession, 0, len(m.manager.cached)) + for _, session := range m.manager.cached { + if session == nil || strings.TrimSpace(session.FilePath) == path { + continue + } + filtered = append(filtered, session) + } + m.manager.cached = filtered + delete(m.manager.cooldowns, path) + files := make([]string, 0, len(m.manager.cfg.CredentialFiles)) + for _, file := range m.manager.cfg.CredentialFiles { + if strings.TrimSpace(file) == path { + continue + } + files = append(files, file) + } + m.manager.cfg.CredentialFiles = files + if len(files) > 0 { + m.manager.cfg.CredentialFile = files[0] + } + return nil +} + +func (m *OAuthLoginManager) ClearCooldown(credentialFile string) error { + if m == nil || m.manager == nil { + return fmt.Errorf("oauth login manager not configured") + } + path := strings.TrimSpace(credentialFile) + if path == "" { + return fmt.Errorf("oauth credential file is empty") + } + m.manager.mu.Lock() + defer m.manager.mu.Unlock() + delete(m.manager.cooldowns, path) + for _, session := range m.manager.cached { + if session != nil && strings.TrimSpace(session.FilePath) == path { + session.CooldownUntil = "" + recordProviderRuntimeChange(m.manager.providerName, "oauth", firstNonEmpty(session.Email, session.AccountID, session.FilePath), "manual_clear_oauth_cooldown", "oauth cooldown cleared from runtime panel") + } + } + return nil +} + +func newOAuthManager(pc config.ProviderConfig, timeout time.Duration) (*oauthManager, error) { + resolved, err := resolveOAuthConfig(pc) + if err != nil { + return nil, err + } + bgCtx, bgCancel := context.WithCancel(context.Background()) + manager := &oauthManager{ + cfg: resolved, + httpClient: newOAuthHTTPClient(resolved.Provider, timeout), + cooldowns: map[string]time.Time{}, + bgCtx: bgCtx, + bgCancel: bgCancel, + } + manager.startBackgroundRefreshLoop() + return manager, nil +} + +func newOAuthHTTPClient(provider string, timeout time.Duration) *http.Client { + if provider == defaultClaudeOAuthProvider { + return newAnthropicOAuthHTTPClient(timeout) + } + return &http.Client{Timeout: timeout} +} + +func resolveOAuthConfig(pc config.ProviderConfig) (oauthConfig, error) { + provider := strings.ToLower(strings.TrimSpace(pc.OAuth.Provider)) + if provider == "" { + provider = strings.ToLower(strings.TrimSpace(pc.Auth)) + } + if provider == "oauth" || provider == "" { + return oauthConfig{}, fmt.Errorf("oauth provider is required") + } + provider = normalizeOAuthProvider(provider) + cfg := oauthConfig{ + Provider: provider, + CredentialFile: strings.TrimSpace(pc.OAuth.CredentialFile), + CredentialFiles: trimNonEmptyStrings(pc.OAuth.CredentialFiles), + CallbackPort: pc.OAuth.CallbackPort, + ClientID: strings.TrimSpace(pc.OAuth.ClientID), + ClientSecret: strings.TrimSpace(pc.OAuth.ClientSecret), + AuthURL: strings.TrimSpace(pc.OAuth.AuthURL), + TokenURL: strings.TrimSpace(pc.OAuth.TokenURL), + RedirectURL: strings.TrimSpace(pc.OAuth.RedirectURL), + Scopes: trimNonEmptyStrings(pc.OAuth.Scopes), + HybridPriority: normalizeHybridPriority(pc.OAuth.HybridPriority), + Cooldown: durationFromSeconds(pc.OAuth.CooldownSec, 15*time.Minute), + RefreshScan: durationFromSeconds(pc.OAuth.RefreshScanSec, 10*time.Minute), + RefreshLead: defaultRefreshLead(provider, pc.OAuth.RefreshLeadSec), + DeviceGrantType: "urn:ietf:params:oauth:grant-type:device_code", + DevicePollMax: 15 * time.Minute, + TokenStyle: oauthStyleForm, + FlowKind: oauthFlowCallback, + } + switch provider { + case defaultCodexOAuthProvider: + cfg.CallbackPort = defaultInt(cfg.CallbackPort, defaultCodexCallbackPort) + cfg.ClientID = firstNonEmpty(cfg.ClientID, defaultCodexClientID) + cfg.AuthURL = firstNonEmpty(cfg.AuthURL, defaultCodexAuthURL) + cfg.TokenURL = firstNonEmpty(cfg.TokenURL, defaultCodexTokenURL) + cfg.RedirectPath = defaultCodexRedirectPath + if len(cfg.Scopes) == 0 { + cfg.Scopes = append([]string(nil), defaultCodexScopes...) + } + case defaultClaudeOAuthProvider: + cfg.CallbackPort = defaultInt(cfg.CallbackPort, defaultClaudeCallbackPort) + cfg.ClientID = firstNonEmpty(cfg.ClientID, defaultClaudeClientID) + cfg.AuthURL = firstNonEmpty(cfg.AuthURL, defaultClaudeAuthURL) + cfg.TokenURL = firstNonEmpty(cfg.TokenURL, defaultClaudeTokenURL) + cfg.RedirectPath = defaultClaudeRedirectPath + cfg.TokenStyle = oauthStyleJSON + if len(cfg.Scopes) == 0 { + cfg.Scopes = append([]string(nil), defaultClaudeScopes...) + } + case defaultAntigravityOAuthProvider: + cfg.CallbackPort = defaultInt(cfg.CallbackPort, defaultAntigravityCallbackPort) + cfg.ClientID = firstNonEmpty(cfg.ClientID, defaultAntigravityClientID) + cfg.ClientSecret = firstNonEmpty(cfg.ClientSecret, defaultAntigravityClientSecret) + cfg.AuthURL = firstNonEmpty(cfg.AuthURL, defaultAntigravityAuthURL) + cfg.TokenURL = firstNonEmpty(cfg.TokenURL, defaultAntigravityTokenURL) + cfg.UserInfoURL = firstNonEmpty(cfg.UserInfoURL, defaultAntigravityUserInfoURL) + cfg.RedirectPath = defaultAntigravityRedirectPath + if len(cfg.Scopes) == 0 { + cfg.Scopes = append([]string(nil), defaultAntigravityScopes...) + } + case defaultGeminiOAuthProvider: + cfg.CallbackPort = defaultInt(cfg.CallbackPort, defaultGeminiCallbackPort) + cfg.ClientID = firstNonEmpty(cfg.ClientID, defaultGeminiClientID) + cfg.ClientSecret = firstNonEmpty(cfg.ClientSecret, defaultGeminiClientSecret) + cfg.AuthURL = firstNonEmpty(cfg.AuthURL, defaultGeminiAuthURL) + cfg.TokenURL = firstNonEmpty(cfg.TokenURL, defaultGeminiTokenURL) + cfg.UserInfoURL = firstNonEmpty(cfg.UserInfoURL, defaultGeminiUserInfoURL) + cfg.RedirectPath = defaultGeminiRedirectPath + if len(cfg.Scopes) == 0 { + cfg.Scopes = append([]string(nil), defaultGoogleScopes...) + } + case defaultKimiOAuthProvider: + cfg.FlowKind = oauthFlowDevice + cfg.ClientID = firstNonEmpty(cfg.ClientID, defaultKimiClientID) + cfg.DeviceCodeURL = firstNonEmpty(strings.TrimSpace(pc.OAuth.AuthURL), defaultKimiDeviceCodeURL) + cfg.TokenURL = firstNonEmpty(cfg.TokenURL, defaultKimiTokenURL) + cfg.AuthURL = cfg.DeviceCodeURL + case defaultQwenOAuthProvider: + cfg.FlowKind = oauthFlowDevice + cfg.ClientID = firstNonEmpty(cfg.ClientID, defaultQwenClientID) + cfg.DeviceCodeURL = firstNonEmpty(strings.TrimSpace(pc.OAuth.AuthURL), defaultQwenDeviceCodeURL) + cfg.TokenURL = firstNonEmpty(cfg.TokenURL, defaultQwenTokenURL) + cfg.AuthURL = cfg.DeviceCodeURL + if len(cfg.Scopes) == 0 { + cfg.Scopes = append([]string(nil), defaultQwenScopes...) + } + default: + return oauthConfig{}, fmt.Errorf("unsupported oauth provider %q", provider) + } + if cfg.FlowKind == oauthFlowCallback && cfg.RedirectURL == "" { + cfg.RedirectURL = fmt.Sprintf("http://localhost:%d%s", cfg.CallbackPort, cfg.RedirectPath) + } + if cfg.CredentialFile == "" { + cfg.CredentialFile = filepath.Join(config.GetConfigDir(), "auth", provider+".json") + } + if len(cfg.CredentialFiles) == 0 { + cfg.CredentialFiles = []string{cfg.CredentialFile} + } else { + cfg.CredentialFiles = uniqueStrings(append([]string{cfg.CredentialFile}, cfg.CredentialFiles...)) + cfg.CredentialFile = cfg.CredentialFiles[0] + } + if cfg.HybridPriority == "" { + cfg.HybridPriority = "api_first" + } + return cfg, nil +} + +func normalizeOAuthProvider(provider string) string { + switch strings.ToLower(strings.TrimSpace(provider)) { + case "anthropic": + return defaultClaudeOAuthProvider + case "gemini-cli": + return defaultGeminiOAuthProvider + default: + return strings.ToLower(strings.TrimSpace(provider)) + } +} + +func defaultRefreshLead(provider string, overrideSec int) time.Duration { + if overrideSec > 0 { + return time.Duration(overrideSec) * time.Second + } + switch strings.ToLower(strings.TrimSpace(provider)) { + case defaultCodexOAuthProvider: + return defaultCodexRefreshLead + case defaultClaudeOAuthProvider: + return defaultClaudeRefreshLead + case defaultAntigravityOAuthProvider: + return defaultAntigravityRefreshLead + case defaultKimiOAuthProvider: + return defaultKimiRefreshLead + case defaultQwenOAuthProvider: + return defaultQwenRefreshLead + default: + return 30 * time.Minute + } +} + +func (m *oauthManager) models(ctx context.Context, apiBase string) ([]string, error) { + attempts, err := m.prepareAttemptsLocked(ctx) + if err != nil { + return nil, err + } + if len(attempts) == 0 { + return nil, fmt.Errorf("oauth session not found, run `clawgo provider login` first") + } + var merged []string + seen := map[string]struct{}{} + var lastErr error + for _, attempt := range attempts { + models, err := fetchOpenAIModels(ctx, m.httpClient, apiBase, attempt.Token) + if err != nil { + lastErr = err + continue + } + attempt.Session.Models = append([]string(nil), models...) + _ = m.persistSessionLocked(attempt.Session) + for _, model := range models { + if _, ok := seen[model]; ok { + continue + } + seen[model] = struct{}{} + merged = append(merged, model) + } + } + if len(merged) > 0 { + return merged, nil + } + if lastErr != nil { + return nil, lastErr + } + return nil, fmt.Errorf("no oauth sessions available") +} + +func (m *oauthManager) login(ctx context.Context, apiBase string, opts OAuthLoginOptions) (*oauthSession, []string, error) { + if m == nil { + return nil, nil, fmt.Errorf("oauth manager not configured") + } + if m.cfg.FlowKind == oauthFlowDevice { + flow, err := m.startDeviceFlow(ctx) + if err != nil { + return nil, nil, err + } + fmt.Printf("Open this URL to continue OAuth login:\n%s\n", flow.AuthURL) + if strings.TrimSpace(flow.UserCode) != "" { + fmt.Printf("User code: %s\n", flow.UserCode) + } + if !opts.NoBrowser { + if err := openBrowser(flow.AuthURL); err != nil { + fmt.Printf("Automatic browser open failed: %v\n", err) + } + } + return m.completeDeviceFlow(ctx, apiBase, flow, opts) + } + pkceVerifier, pkceChallenge, err := generatePKCE() + if err != nil { + return nil, nil, err + } + state, err := randomURLToken(24) + if err != nil { + return nil, nil, err + } + callback, err := m.obtainOAuthCallback(ctx, state, pkceChallenge, opts) + if err != nil { + return nil, nil, err + } + if callback.State != state { + return nil, nil, fmt.Errorf("oauth callback state mismatch") + } + if callback.Err != "" { + return nil, nil, fmt.Errorf("oauth callback returned error: %s", callback.Err) + } + return m.completeLogin(ctx, apiBase, pkceVerifier, callback, state, opts) +} + +func (m *oauthManager) completeLogin(ctx context.Context, apiBase, pkceVerifier string, callback *oauthCallbackResult, state string, opts OAuthLoginOptions) (*oauthSession, []string, error) { + session, err := m.exchangeCode(ctx, callback.Code, pkceVerifier, state) + if err != nil { + return nil, nil, err + } + if err := m.applyAccountLabel(session, opts); err != nil { + return nil, nil, err + } + models, _ := fetchOpenAIModels(ctx, m.httpClient, apiBase, session.AccessToken) + if len(models) > 0 { + session.Models = append([]string(nil), models...) + } + + m.mu.Lock() + defer m.mu.Unlock() + path, err := m.allocateCredentialPathLocked(session) + if err != nil { + return nil, nil, err + } + session.FilePath = path + if err := m.persistSessionLocked(session); err != nil { + return nil, nil, err + } + m.cached = appendLoadedSession(m.cached, session) + m.cfg.CredentialFiles = uniqueStrings(append(m.cfg.CredentialFiles, path)) + m.cfg.CredentialFile = m.cfg.CredentialFiles[0] + return session, models, nil +} + +func (m *oauthManager) importSession(ctx context.Context, apiBase, fileName string, data []byte, opts OAuthLoginOptions) (*oauthSession, []string, error) { + session, err := parseImportedOAuthSession(m.cfg.Provider, fileName, data) + if err != nil { + return nil, nil, err + } + if strings.TrimSpace(session.AccessToken) == "" && strings.TrimSpace(session.RefreshToken) == "" { + return nil, nil, fmt.Errorf("auth.json missing access_token/refresh_token") + } + if strings.TrimSpace(session.AccessToken) == "" && strings.TrimSpace(session.RefreshToken) != "" { + refreshed, refreshErr := m.refreshImportedSession(ctx, session) + if refreshErr != nil { + return nil, nil, refreshErr + } + session = refreshed + } + session, err = m.enrichSession(ctx, session) + if err != nil { + return nil, nil, err + } + if err := m.applyAccountLabel(session, opts); err != nil { + return nil, nil, err + } + models, _ := fetchOpenAIModels(ctx, m.httpClient, apiBase, session.AccessToken) + if len(models) > 0 { + session.Models = append([]string(nil), models...) + } + + m.mu.Lock() + defer m.mu.Unlock() + path, err := m.allocateCredentialPathLocked(session) + if err != nil { + return nil, nil, err + } + session.FilePath = path + if err := m.persistSessionLocked(session); err != nil { + return nil, nil, err + } + m.cached = appendLoadedSession(m.cached, session) + m.cfg.CredentialFiles = uniqueStrings(append(m.cfg.CredentialFiles, path)) + m.cfg.CredentialFile = m.cfg.CredentialFiles[0] + return session, models, nil +} + +func (m *oauthManager) refreshImportedSession(ctx context.Context, session *oauthSession) (*oauthSession, error) { + if session == nil { + return nil, fmt.Errorf("oauth session is nil") + } + return m.refreshSessionData(ctx, session) +} + +func (m *oauthManager) prepareAttemptsLocked(ctx context.Context) ([]oauthAttempt, error) { + m.mu.Lock() + defer m.mu.Unlock() + + sessions, err := m.loadAllLocked() + if err != nil { + return nil, err + } + if len(sessions) == 0 { + return nil, nil + } + attempts := make([]oauthAttempt, 0, len(sessions)) + for _, session := range sessions { + if session == nil || strings.TrimSpace(session.AccessToken) == "" { + continue + } + if m.sessionOnCooldown(session) { + continue + } + if sessionNeedsRefresh(session, time.Minute) && strings.TrimSpace(session.RefreshToken) != "" { + refreshed, err := m.refreshSessionLocked(ctx, session) + if err == nil { + session = refreshed + } + } + token := strings.TrimSpace(session.AccessToken) + if token == "" { + continue + } + attempts = append(attempts, oauthAttempt{Session: session, Token: token}) + } + sortOAuthAttempts(attempts) + return attempts, nil +} + +func (m *oauthManager) markExhausted(session *oauthSession, reason oauthFailureReason) { + if session == nil { + return + } + m.mu.Lock() + defer m.mu.Unlock() + if m.cfg.Cooldown > 0 && strings.TrimSpace(session.FilePath) != "" { + until := time.Now().Add(m.cooldownForReason(reason)) + m.cooldowns[strings.TrimSpace(session.FilePath)] = until + session.CooldownUntil = until.Format(time.RFC3339) + } + session.FailureCount++ + session.LastFailure = string(reason) + if session.HealthScore == 0 { + session.HealthScore = 100 + } + session.HealthScore = maxInt(1, session.HealthScore-healthPenaltyForReason(reason)) + recordProviderRuntimeChange(m.providerName, "oauth", firstNonEmpty(session.Email, session.AccountID, session.FilePath), "oauth_cooldown_"+string(reason), "oauth credential entered cooldown after request failure") + sessions, err := m.loadAllLocked() + if err != nil { + return + } + rotated := make([]*oauthSession, 0, len(sessions)) + for _, item := range sessions { + if item == nil { + continue + } + if item.FilePath == session.FilePath { + continue + } + rotated = append(rotated, item) + } + rotated = append(rotated, session) + m.cached = rotated +} + +func (m *oauthManager) markSuccess(session *oauthSession) { + if session == nil { + return + } + m.mu.Lock() + defer m.mu.Unlock() + path := strings.TrimSpace(session.FilePath) + wasCooling := path != "" && session.CooldownUntil != "" + if path != "" { + delete(m.cooldowns, path) + session.CooldownUntil = "" + } + if session.HealthScore == 0 { + session.HealthScore = 100 + } else { + session.HealthScore = minInt(100, session.HealthScore+3) + } + if wasCooling { + recordProviderRuntimeChange(m.providerName, "oauth", firstNonEmpty(session.Email, session.AccountID, session.FilePath), "oauth_recovered", "oauth cooldown cleared after successful request") + } +} + +func (m *oauthManager) sessionOnCooldown(session *oauthSession) bool { + if m == nil || session == nil { + return false + } + path := strings.TrimSpace(session.FilePath) + if path == "" { + return false + } + until, ok := m.cooldowns[path] + if !ok { + return false + } + if time.Now().Before(until) { + session.CooldownUntil = until.Format(time.RFC3339) + return true + } + delete(m.cooldowns, path) + session.CooldownUntil = "" + return false +} + +func (m *oauthManager) cooldownForReason(reason oauthFailureReason) time.Duration { + base := m.cfg.Cooldown + switch reason { + case oauthFailureQuota: + return base * 4 + case oauthFailureForbidden: + return base * 2 + case oauthFailureRateLimit: + return base + default: + return base + } +} + +func healthPenaltyForReason(reason oauthFailureReason) int { + switch reason { + case oauthFailureQuota: + return 40 + case oauthFailureForbidden: + return 25 + case oauthFailureRateLimit: + return 10 + default: + return 10 + } +} + +func sessionHealthScore(session *oauthSession) int { + if session == nil { + return 0 + } + if session.HealthScore <= 0 { + return 100 + } + return session.HealthScore +} + +func sortOAuthAttempts(attempts []oauthAttempt) { + for i := 0; i < len(attempts); i++ { + for j := i + 1; j < len(attempts); j++ { + left := sessionHealthScore(attempts[i].Session) + right := sessionHealthScore(attempts[j].Session) + if right > left { + attempts[i], attempts[j] = attempts[j], attempts[i] + } + } + } +} + +func maxInt(a, b int) int { + if a > b { + return a + } + return b +} + +func minInt(a, b int) int { + if a < b { + return a + } + return b +} + +func (m *oauthManager) credentialFiles() []string { + return uniqueStrings(append([]string(nil), m.cfg.CredentialFiles...)) +} + +func (m *oauthManager) startBackgroundRefreshLoop() { + if m == nil || m.bgCtx == nil { + return + } + go func() { + ticker := time.NewTicker(m.cfg.RefreshScan) + defer ticker.Stop() + for { + select { + case <-m.bgCtx.Done(): + return + case <-ticker.C: + _, _ = m.refreshExpiringSessions(m.bgCtx, m.cfg.RefreshLead) + } + } + }() +} + +func (m *oauthManager) refreshExpiringSessions(ctx context.Context, lead time.Duration) (*ProviderRefreshResult, error) { + if m == nil { + return &ProviderRefreshResult{}, nil + } + m.mu.Lock() + defer m.mu.Unlock() + + sessions, err := m.loadAllLocked() + if err != nil { + return nil, err + } + result := &ProviderRefreshResult{Provider: strings.TrimSpace(m.providerName)} + for _, session := range sessions { + if session == nil { + continue + } + result.Checked++ + target := firstNonEmpty(session.Email, session.AccountID, session.FilePath) + if strings.TrimSpace(session.RefreshToken) == "" { + result.Skipped++ + result.Accounts = append(result.Accounts, ProviderRefreshAccountResult{Target: target, Status: "skipped", Detail: "missing refresh_token", Expire: session.Expire}) + continue + } + if !sessionNeedsRefresh(session, lead) { + result.Skipped++ + result.Accounts = append(result.Accounts, ProviderRefreshAccountResult{Target: target, Status: "skipped", Detail: "not expiring within lead window", Expire: session.Expire}) + continue + } + refreshed, refreshErr := m.refreshSessionLocked(ctx, session) + if refreshErr != nil { + result.Failed++ + result.Accounts = append(result.Accounts, ProviderRefreshAccountResult{Target: target, Status: "failed", Detail: refreshErr.Error(), Expire: session.Expire}) + continue + } + m.cached = appendLoadedSession(m.cached, refreshed) + result.Refreshed++ + result.Accounts = append(result.Accounts, ProviderRefreshAccountResult{Target: firstNonEmpty(refreshed.Email, refreshed.AccountID, refreshed.FilePath), Status: "refreshed", Detail: "token refreshed", Expire: refreshed.Expire}) + } + return result, nil +} + +func (m *oauthManager) obtainOAuthCallback(ctx context.Context, state, pkceChallenge string, opts OAuthLoginOptions) (*oauthCallbackResult, error) { + authURL := m.authorizationURL(state, pkceChallenge) + if opts.Manual { + return waitForOAuthCodeManual(authURL, readerOrStdin(opts.Reader)) + } + return waitForOAuthCode(ctx, m.cfg.CallbackPort, m.cfg.RedirectPath, authURL, opts.NoBrowser) +} + +func (m *oauthManager) authorizationURL(state, pkceChallenge string) string { + v := url.Values{} + v.Set("client_id", m.cfg.ClientID) + v.Set("response_type", "code") + v.Set("redirect_uri", m.cfg.RedirectURL) + if len(m.cfg.Scopes) > 0 { + v.Set("scope", strings.Join(m.cfg.Scopes, " ")) + } + v.Set("state", state) + if pkceChallenge != "" { + v.Set("code_challenge", pkceChallenge) + v.Set("code_challenge_method", "S256") + } + switch m.cfg.Provider { + case defaultCodexOAuthProvider: + v.Set("prompt", "login") + v.Set("id_token_add_organizations", "true") + v.Set("codex_cli_simplified_flow", "true") + case defaultClaudeOAuthProvider: + v.Set("code", "true") + case defaultAntigravityOAuthProvider, defaultGeminiOAuthProvider: + v.Set("access_type", "offline") + v.Set("prompt", "consent") + } + return m.cfg.AuthURL + "?" + v.Encode() +} + +func (m *oauthManager) exchangeCode(ctx context.Context, code, verifier, state string) (*oauthSession, error) { + switch m.cfg.Provider { + case defaultClaudeOAuthProvider: + reqBody := map[string]any{ + "code": code, + "state": state, + "grant_type": "authorization_code", + "client_id": m.cfg.ClientID, + "redirect_uri": m.cfg.RedirectURL, + "code_verifier": verifier, + } + raw, err := m.doJSONTokenRequest(ctx, reqBody) + if err != nil { + return nil, err + } + return sessionFromTokenPayload(m.cfg.Provider, raw) + default: + form := url.Values{} + form.Set("grant_type", "authorization_code") + form.Set("client_id", m.cfg.ClientID) + form.Set("code", code) + form.Set("redirect_uri", m.cfg.RedirectURL) + if verifier != "" { + form.Set("code_verifier", verifier) + } + if m.cfg.ClientSecret != "" { + form.Set("client_secret", m.cfg.ClientSecret) + } + raw, err := m.doFormTokenRequest(ctx, form) + if err != nil { + return nil, err + } + session, err := sessionFromTokenPayload(m.cfg.Provider, raw) + if err != nil { + return nil, err + } + return m.enrichSession(ctx, session) + } +} + +func (m *oauthManager) refreshSessionLocked(ctx context.Context, session *oauthSession) (*oauthSession, error) { + refreshed, err := m.refreshSessionData(ctx, session) + if err != nil { + return nil, err + } + refreshed.FilePath = session.FilePath + refreshed.HealthScore = minInt(100, maxInt(sessionHealthScore(session), refreshed.HealthScore)+5) + refreshed.FailureCount = session.FailureCount + refreshed.LastFailure = session.LastFailure + refreshed.CooldownUntil = session.CooldownUntil + if err := m.persistSessionLocked(refreshed); err != nil { + return nil, err + } + m.cached = appendLoadedSession(m.cached, refreshed) + recordProviderRuntimeChange(m.providerName, "oauth", firstNonEmpty(refreshed.Email, refreshed.AccountID, refreshed.FilePath), "oauth_refresh_success", "refresh token exchanged for new access token") + return refreshed, nil +} + +func (m *oauthManager) refreshSessionData(ctx context.Context, session *oauthSession) (*oauthSession, error) { + switch m.cfg.Provider { + case defaultClaudeOAuthProvider: + raw, err := m.doJSONTokenRequest(ctx, map[string]any{ + "client_id": m.cfg.ClientID, + "grant_type": "refresh_token", + "refresh_token": session.RefreshToken, + }) + if err != nil { + return nil, err + } + refreshed, err := sessionFromTokenPayload(m.cfg.Provider, raw) + if err != nil { + return nil, err + } + return mergeOAuthSession(session, refreshed), nil + case defaultGeminiOAuthProvider: + refreshed, err := m.refreshGoogleTokenSession(ctx, session) + if err != nil { + return nil, err + } + return mergeOAuthSession(session, refreshed), nil + default: + form := url.Values{} + form.Set("client_id", m.cfg.ClientID) + form.Set("grant_type", "refresh_token") + form.Set("refresh_token", session.RefreshToken) + if m.cfg.ClientSecret != "" { + form.Set("client_secret", m.cfg.ClientSecret) + } + if len(m.cfg.Scopes) > 0 && m.cfg.Provider != defaultQwenOAuthProvider && m.cfg.Provider != defaultKimiOAuthProvider { + form.Set("scope", strings.Join(m.cfg.Scopes, " ")) + } + raw, err := m.doFormTokenRequest(ctx, form) + if err != nil { + return nil, err + } + refreshed, err := sessionFromTokenPayload(m.cfg.Provider, raw) + if err != nil { + return nil, err + } + refreshed = mergeOAuthSession(session, refreshed) + return m.enrichSession(ctx, refreshed) + } +} + +func (m *oauthManager) refreshGoogleTokenSession(ctx context.Context, session *oauthSession) (*oauthSession, error) { + tokenData := cloneStringAnyMap(session.Token) + if len(tokenData) == 0 { + tokenData = map[string]any{} + } + tokenURL := firstNonEmpty(asString(tokenData["token_uri"]), m.cfg.TokenURL, defaultGeminiTokenURL) + clientID := firstNonEmpty(asString(tokenData["client_id"]), m.cfg.ClientID) + clientSecret := firstNonEmpty(asString(tokenData["client_secret"]), m.cfg.ClientSecret) + refreshToken := firstNonEmpty(asString(tokenData["refresh_token"]), session.RefreshToken) + if tokenURL == "" || clientID == "" || refreshToken == "" { + return nil, fmt.Errorf("oauth token refresh failed: gemini token metadata incomplete") + } + form := url.Values{} + form.Set("client_id", clientID) + form.Set("grant_type", "refresh_token") + form.Set("refresh_token", refreshToken) + if clientSecret != "" { + form.Set("client_secret", clientSecret) + } + raw, err := m.doFormTokenRequestURL(ctx, tokenURL, form) + if err != nil { + return nil, err + } + refreshed, err := sessionFromTokenPayload(m.cfg.Provider, raw) + if err != nil { + return nil, err + } + tokenData["access_token"] = refreshed.AccessToken + if refreshed.RefreshToken != "" { + tokenData["refresh_token"] = refreshed.RefreshToken + } + if refreshed.IDToken != "" { + tokenData["id_token"] = refreshed.IDToken + } + if refreshed.Expire != "" { + tokenData["expiry"] = refreshed.Expire + } + tokenData["token_uri"] = tokenURL + tokenData["client_id"] = clientID + if clientSecret != "" { + tokenData["client_secret"] = clientSecret + } + if len(m.cfg.Scopes) > 0 && tokenData["scopes"] == nil { + tokenData["scopes"] = append([]string(nil), m.cfg.Scopes...) + } + refreshed.Token = tokenData + return m.enrichSession(ctx, mergeOAuthSession(session, refreshed)) +} + +func (m *oauthManager) enrichSession(ctx context.Context, session *oauthSession) (*oauthSession, error) { + if session == nil { + return nil, fmt.Errorf("oauth session is nil") + } + switch m.cfg.Provider { + case defaultAntigravityOAuthProvider, defaultGeminiOAuthProvider: + if strings.TrimSpace(session.Email) == "" && m.cfg.UserInfoURL != "" && session.AccessToken != "" { + email, err := m.fetchUserEmail(ctx, session.AccessToken) + if err == nil { + session.Email = email + } + } + if m.cfg.Provider == defaultAntigravityOAuthProvider && strings.TrimSpace(session.ProjectID) == "" && session.AccessToken != "" { + projectID, err := m.fetchAntigravityProjectID(ctx, session.AccessToken) + if err == nil { + session.ProjectID = projectID + } + } + } + return session, nil +} + +func (m *oauthManager) loadAllLocked() ([]*oauthSession, error) { + if len(m.cached) > 0 { + return cloneSessions(m.cached), nil + } + files := m.credentialFiles() + out := make([]*oauthSession, 0, len(files)) + for _, path := range files { + raw, err := os.ReadFile(path) + if err != nil { + if os.IsNotExist(err) { + continue + } + return nil, fmt.Errorf("read oauth credential file failed: %w", err) + } + session, err := parseImportedOAuthSession(m.cfg.Provider, filepath.Base(path), raw) + if err != nil { + return nil, fmt.Errorf("decode oauth credential file failed: %w", err) + } + session.FilePath = path + if until, ok := m.cooldowns[strings.TrimSpace(path)]; ok && time.Now().Before(until) { + session.CooldownUntil = until.Format(time.RFC3339) + } + out = append(out, session) + } + m.cached = cloneSessions(out) + return cloneSessions(out), nil +} + +func (m *oauthManager) persistSessionLocked(session *oauthSession) error { + if session == nil { + return fmt.Errorf("oauth session is nil") + } + m.prepareSessionForPersist(session) + path := strings.TrimSpace(session.FilePath) + if path == "" { + path = m.cfg.CredentialFile + } + if err := os.MkdirAll(filepath.Dir(path), 0o700); err != nil { + return fmt.Errorf("create oauth credential dir failed: %w", err) + } + data, err := json.MarshalIndent(session, "", " ") + if err != nil { + return fmt.Errorf("encode oauth credential file failed: %w", err) + } + if err := os.WriteFile(path, data, 0o600); err != nil { + return fmt.Errorf("write oauth credential file failed: %w", err) + } + session.FilePath = path + return nil +} + +func (m *oauthManager) prepareSessionForPersist(session *oauthSession) { + if session == nil { + return + } + switch m.cfg.Provider { + case defaultGeminiOAuthProvider: + if session.Token == nil { + session.Token = map[string]any{} + } + if session.Token["access_token"] == nil && strings.TrimSpace(session.AccessToken) != "" { + session.Token["access_token"] = strings.TrimSpace(session.AccessToken) + } + if session.Token["refresh_token"] == nil && strings.TrimSpace(session.RefreshToken) != "" { + session.Token["refresh_token"] = strings.TrimSpace(session.RefreshToken) + } + if session.Token["id_token"] == nil && strings.TrimSpace(session.IDToken) != "" { + session.Token["id_token"] = strings.TrimSpace(session.IDToken) + } + if session.Token["expiry"] == nil && strings.TrimSpace(session.Expire) != "" { + session.Token["expiry"] = strings.TrimSpace(session.Expire) + } + if session.Token["token_uri"] == nil { + session.Token["token_uri"] = firstNonEmpty(m.cfg.TokenURL, defaultGeminiTokenURL) + } + if session.Token["client_id"] == nil { + if clientID := strings.TrimSpace(m.cfg.ClientID); clientID != "" { + session.Token["client_id"] = clientID + } + } + if session.Token["client_secret"] == nil { + if clientSecret := strings.TrimSpace(m.cfg.ClientSecret); clientSecret != "" { + session.Token["client_secret"] = clientSecret + } + } + if session.Token["scopes"] == nil { + session.Token["scopes"] = append([]string(nil), m.cfg.Scopes...) + } + if session.Token["universe_domain"] == nil { + session.Token["universe_domain"] = "googleapis.com" + } + case defaultAntigravityOAuthProvider: + if session.Token == nil { + session.Token = map[string]any{} + } + session.Token["type"] = defaultAntigravityOAuthProvider + if strings.TrimSpace(session.AccessToken) != "" { + session.Token["access_token"] = strings.TrimSpace(session.AccessToken) + } + if strings.TrimSpace(session.RefreshToken) != "" { + session.Token["refresh_token"] = strings.TrimSpace(session.RefreshToken) + } + if strings.TrimSpace(session.TokenType) != "" { + session.Token["token_type"] = strings.TrimSpace(session.TokenType) + } + if strings.TrimSpace(session.Expire) != "" { + session.Token["expired"] = strings.TrimSpace(session.Expire) + } + if strings.TrimSpace(session.Email) != "" { + session.Token["email"] = strings.TrimSpace(session.Email) + } + if strings.TrimSpace(session.ProjectID) != "" { + session.Token["project_id"] = strings.TrimSpace(session.ProjectID) + } + if strings.TrimSpace(session.Scope) != "" { + session.Token["scope"] = strings.TrimSpace(session.Scope) + } + if strings.TrimSpace(session.Email) != "" { + session.Token["account_label"] = strings.TrimSpace(session.Email) + } + } +} + +func (m *oauthManager) applyAccountLabel(session *oauthSession, opts OAuthLoginOptions) error { + if session == nil { + return fmt.Errorf("oauth session is nil") + } + label := strings.TrimSpace(opts.AccountLabel) + switch m.cfg.Provider { + case defaultQwenOAuthProvider: + if strings.TrimSpace(session.Email) == "" { + if label == "" { + return fmt.Errorf("qwen oauth requires account_label when email is unavailable") + } + session.Email = label + } + default: + if strings.TrimSpace(session.Email) == "" && label != "" { + session.Email = label + } + } + return nil +} + +func sessionLabel(session *oauthSession) string { + if session == nil { + return "" + } + return firstNonEmpty(session.Email, session.AccountID, session.ProjectID) +} + +func (m *oauthManager) allocateCredentialPathLocked(session *oauthSession) (string, error) { + files := m.credentialFiles() + if len(files) == 1 { + path := files[0] + if _, err := os.Stat(path); os.IsNotExist(err) { + return path, nil + } + } + baseDir := filepath.Dir(m.cfg.CredentialFile) + if err := os.MkdirAll(baseDir, 0o700); err != nil { + return "", err + } + label := sanitizeFileToken(firstNonEmpty(session.Email, session.AccountID, session.ProjectID, "account")) + name := fmt.Sprintf("%s-%s.json", label, time.Now().UTC().Format("20060102-150405")) + return filepath.Join(baseDir, name), nil +} + +func sessionFromTokenPayload(provider string, raw map[string]any) (*oauthSession, error) { + resp := oauthTokenResponse{} + payloadBytes, err := json.Marshal(raw) + if err == nil { + _ = json.Unmarshal(payloadBytes, &resp) + } + session := &oauthSession{ + Provider: provider, + Type: provider, + AccessToken: strings.TrimSpace(resp.AccessToken), + RefreshToken: strings.TrimSpace(resp.RefreshToken), + IDToken: strings.TrimSpace(resp.IDToken), + TokenType: strings.TrimSpace(resp.TokenType), + ResourceURL: strings.TrimSpace(resp.ResourceURL), + Scope: strings.TrimSpace(resp.Scope), + LastRefresh: time.Now().Format(time.RFC3339), + } + if resp.ExpiresIn > 0 { + session.Expire = time.Now().Add(time.Duration(resp.ExpiresIn) * time.Second).Format(time.RFC3339) + } + switch provider { + case defaultClaudeOAuthProvider: + session.Email = firstNonEmpty(strings.TrimSpace(resp.Account.EmailAddress), asString(raw["email"])) + session.AccountID = firstNonEmpty(strings.TrimSpace(resp.Account.UUID), strings.TrimSpace(resp.Organization.UUID)) + case defaultAntigravityOAuthProvider, defaultGeminiOAuthProvider: + session.Email = asString(raw["email"]) + session.ProjectID = asString(raw["project_id"]) + case defaultQwenOAuthProvider: + session.Email = asString(raw["email"]) + case defaultKimiOAuthProvider: + session.DeviceID = asString(raw["device_id"]) + default: + claims := parseJWTClaims(resp.IDToken) + session.Email = strings.TrimSpace(fmt.Sprintf("%v", claims["email"])) + session.AccountID = firstNonEmpty( + strings.TrimSpace(fmt.Sprintf("%v", claims["https://api.openai.com/auth"])), + strings.TrimSpace(fmt.Sprintf("%v", claims["account_id"])), + strings.TrimSpace(fmt.Sprintf("%v", claims["sub"])), + ) + } + if session.Email == "" { + session.Email = firstNonEmpty(asString(raw["email"]), asString(raw["account_id"])) + } + if session.AccountID == "" { + session.AccountID = firstNonEmpty(asString(raw["account_id"]), asString(raw["sub"])) + } + if session.Expire == "" { + session.Expire = firstNonEmpty(asString(raw["expire"]), asString(raw["expired"]), asString(raw["expiry_date"]), asString(raw["expiry"])) + } + if provider == defaultGeminiOAuthProvider { + if tokenMap := mapFromAny(raw["token"]); len(tokenMap) > 0 { + session.Token = tokenMap + session.AccessToken = firstNonEmpty(session.AccessToken, asString(tokenMap["access_token"])) + session.RefreshToken = firstNonEmpty(session.RefreshToken, asString(tokenMap["refresh_token"])) + session.IDToken = firstNonEmpty(session.IDToken, asString(tokenMap["id_token"])) + session.Expire = firstNonEmpty(session.Expire, asString(tokenMap["expiry"]), asString(tokenMap["expiry_date"])) + } + } + if session.AccessToken == "" && session.RefreshToken == "" { + return nil, fmt.Errorf("oauth token payload missing access_token/refresh_token") + } + return session, nil +} + +func mergeOAuthSession(prev, next *oauthSession) *oauthSession { + if next == nil { + return prev + } + if prev == nil { + return next + } + merged := *next + merged.Provider = firstNonEmpty(next.Provider, prev.Provider) + merged.Type = firstNonEmpty(next.Type, prev.Type) + merged.RefreshToken = firstNonEmpty(next.RefreshToken, prev.RefreshToken) + merged.Email = firstNonEmpty(next.Email, prev.Email) + merged.AccountID = firstNonEmpty(next.AccountID, prev.AccountID) + merged.ProjectID = firstNonEmpty(next.ProjectID, prev.ProjectID) + merged.DeviceID = firstNonEmpty(next.DeviceID, prev.DeviceID) + merged.ResourceURL = firstNonEmpty(next.ResourceURL, prev.ResourceURL) + merged.Scope = firstNonEmpty(next.Scope, prev.Scope) + merged.Models = append([]string(nil), prev.Models...) + if len(next.Models) > 0 { + merged.Models = append([]string(nil), next.Models...) + } + merged.FilePath = prev.FilePath + if merged.Expire == "" { + merged.Expire = prev.Expire + } + if len(next.Token) > 0 { + merged.Token = cloneStringAnyMap(next.Token) + } else if len(prev.Token) > 0 { + merged.Token = cloneStringAnyMap(prev.Token) + } + if merged.LastRefresh == "" { + merged.LastRefresh = time.Now().Format(time.RFC3339) + } + return &merged +} + +func sessionNeedsRefresh(session *oauthSession, lead time.Duration) bool { + if session == nil { + return false + } + expireAt, err := time.Parse(time.RFC3339, strings.TrimSpace(session.Expire)) + if err != nil { + return false + } + return time.Now().Add(lead).After(expireAt) +} + +func fetchOpenAIModels(ctx context.Context, client *http.Client, apiBase, token string) ([]string, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpointFor(apiBase, "/models"), nil) + if err != nil { + return nil, err + } + req.Header.Set("Accept", "application/json") + if strings.TrimSpace(token) != "" { + req.Header.Set("Authorization", "Bearer "+strings.TrimSpace(token)) + } + resp, err := client.Do(req) + if err != nil { + return nil, fmt.Errorf("fetch models failed: %w", err) + } + defer resp.Body.Close() + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("read models response failed: %w", err) + } + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("fetch models failed: status=%d body=%s", resp.StatusCode, strings.TrimSpace(string(body))) + } + var payload struct { + Data []struct { + ID string `json:"id"` + } `json:"data"` + } + if err := json.Unmarshal(body, &payload); err != nil { + return nil, fmt.Errorf("decode models response failed: %w", err) + } + seen := map[string]struct{}{} + out := make([]string, 0, len(payload.Data)) + for _, item := range payload.Data { + id := strings.TrimSpace(item.ID) + if id == "" { + continue + } + if _, ok := seen[id]; ok { + continue + } + seen[id] = struct{}{} + out = append(out, id) + } + return out, nil +} + +func waitForOAuthCode(ctx context.Context, port int, callbackPath, authURL string, noBrowser bool) (*oauthCallbackResult, error) { + if port <= 0 { + return nil, fmt.Errorf("oauth callback port must be > 0") + } + if callbackPath == "" { + callbackPath = "/auth/callback" + } + listener, err := net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", port)) + if err != nil { + return nil, fmt.Errorf("oauth callback listener failed: %w", err) + } + defer listener.Close() + + resultCh := make(chan *oauthCallbackResult, 1) + errCh := make(chan error, 1) + server := &http.Server{Handler: http.NewServeMux()} + server.Handler.(*http.ServeMux).HandleFunc(callbackPath, func(w http.ResponseWriter, r *http.Request) { + query := r.URL.Query() + if errText := strings.TrimSpace(query.Get("error")); errText != "" { + http.Error(w, "OAuth login failed: "+errText, http.StatusBadRequest) + select { + case resultCh <- &oauthCallbackResult{Err: errText, State: strings.TrimSpace(query.Get("state"))}: + default: + } + return + } + code := strings.TrimSpace(query.Get("code")) + if code == "" { + http.Error(w, "missing oauth code", http.StatusBadRequest) + select { + case errCh <- fmt.Errorf("oauth callback missing code"): + default: + } + return + } + _, _ = io.WriteString(w, "OAuth login complete. You can close this window and return to clawgo.") + select { + case resultCh <- &oauthCallbackResult{Code: code, State: strings.TrimSpace(query.Get("state"))}: + default: + } + }) + go func() { _ = server.Serve(listener) }() + + fmt.Printf("Open this URL to continue OAuth login:\n%s\n", authURL) + if !noBrowser { + if err := openBrowser(authURL); err != nil { + fmt.Printf("Automatic browser open failed: %v\n", err) + } + } + defer func() { + shutdownCtx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + _ = server.Shutdown(shutdownCtx) + }() + select { + case <-ctx.Done(): + return nil, ctx.Err() + case err := <-errCh: + return nil, err + case result := <-resultCh: + return result, nil + } +} + +func waitForOAuthCodeManual(authURL string, reader io.Reader) (*oauthCallbackResult, error) { + fmt.Printf("Open this URL to continue OAuth login:\n%s\n", authURL) + fmt.Println("After login, copy the final callback URL from the browser and paste it below.") + br := bufio.NewReader(reader) + fmt.Print("callback_url: ") + line, err := br.ReadString('\n') + if err != nil && err != io.EOF { + return nil, fmt.Errorf("read callback url failed: %w", err) + } + line = strings.TrimSpace(line) + if line == "" { + return nil, fmt.Errorf("callback url is empty") + } + return parseOAuthCallbackURL(line) +} + +func readerOrStdin(reader io.Reader) io.Reader { + if reader != nil { + return reader + } + return os.Stdin +} + +func openBrowser(target string) error { + var cmd *exec.Cmd + switch runtime.GOOS { + case "darwin": + cmd = exec.Command("open", target) + case "windows": + cmd = exec.Command("rundll32", "url.dll,FileProtocolHandler", target) + default: + cmd = exec.Command("xdg-open", target) + } + return cmd.Start() +} + +func generatePKCE() (string, string, error) { + verifier, err := randomURLToken(64) + if err != nil { + return "", "", err + } + sum := sha256.Sum256([]byte(verifier)) + challenge := base64.RawURLEncoding.EncodeToString(sum[:]) + return verifier, challenge, nil +} + +func randomURLToken(size int) (string, error) { + if size <= 0 { + size = 32 + } + buf := make([]byte, size) + if _, err := rand.Read(buf); err != nil { + return "", err + } + return base64.RawURLEncoding.EncodeToString(buf), nil +} + +func parseJWTClaims(token string) map[string]any { + token = strings.TrimSpace(token) + if token == "" { + return map[string]any{} + } + parts := strings.Split(token, ".") + if len(parts) < 2 { + return map[string]any{} + } + raw, err := base64.RawURLEncoding.DecodeString(parts[1]) + if err != nil { + return map[string]any{} + } + var claims map[string]any + if err := json.Unmarshal(raw, &claims); err != nil { + return map[string]any{} + } + return claims +} + +func firstNonEmpty(values ...string) string { + for _, value := range values { + if trimmed := strings.TrimSpace(value); trimmed != "" { + return trimmed + } + } + return "" +} + +func trimNonEmptyStrings(values []string) []string { + out := make([]string, 0, len(values)) + for _, value := range values { + if trimmed := strings.TrimSpace(value); trimmed != "" { + out = append(out, trimmed) + } + } + return out +} + +func uniqueStrings(values []string) []string { + out := make([]string, 0, len(values)) + seen := map[string]struct{}{} + for _, value := range values { + trimmed := strings.TrimSpace(value) + if trimmed == "" { + continue + } + if _, ok := seen[trimmed]; ok { + continue + } + seen[trimmed] = struct{}{} + out = append(out, trimmed) + } + return out +} + +func cloneSessions(values []*oauthSession) []*oauthSession { + out := make([]*oauthSession, 0, len(values)) + for _, value := range values { + if value == nil { + continue + } + cp := *value + cp.Models = append([]string(nil), value.Models...) + cp.Token = cloneStringAnyMap(value.Token) + out = append(out, &cp) + } + return out +} + +func appendLoadedSession(values []*oauthSession, session *oauthSession) []*oauthSession { + filtered := make([]*oauthSession, 0, len(values)+1) + for _, value := range values { + if value == nil || value.FilePath == session.FilePath { + continue + } + filtered = append(filtered, value) + } + filtered = append(filtered, session) + return filtered +} + +func sanitizeFileToken(value string) string { + value = strings.ToLower(strings.TrimSpace(value)) + if value == "" { + return "account" + } + var b strings.Builder + lastDash := false + for _, r := range value { + if (r >= 'a' && r <= 'z') || (r >= '0' && r <= '9') { + b.WriteRune(r) + lastDash = false + continue + } + if !lastDash { + b.WriteByte('-') + lastDash = true + } + } + out := strings.Trim(b.String(), "-") + if out == "" { + return "account" + } + return out +} + +func durationFromSeconds(value int, fallback time.Duration) time.Duration { + if value <= 0 { + return fallback + } + return time.Duration(value) * time.Second +} + +func normalizeHybridPriority(value string) string { + switch strings.ToLower(strings.TrimSpace(value)) { + case "", "api_first": + return "api_first" + case "oauth_first": + return "oauth_first" + default: + return "" + } +} + +func parseOAuthCallbackURL(raw string) (*oauthCallbackResult, error) { + parsed, err := url.Parse(strings.TrimSpace(raw)) + if err != nil { + return nil, fmt.Errorf("invalid callback url: %w", err) + } + query := parsed.Query() + if errText := strings.TrimSpace(query.Get("error")); errText != "" { + return nil, fmt.Errorf("oauth callback returned error: %s", errText) + } + code := strings.TrimSpace(query.Get("code")) + if code == "" { + return nil, fmt.Errorf("oauth callback missing code") + } + return &oauthCallbackResult{ + Code: code, + State: strings.TrimSpace(query.Get("state")), + }, nil +} + +func parseImportedOAuthSession(provider, fileName string, data []byte) (*oauthSession, error) { + var raw map[string]any + if err := json.Unmarshal(data, &raw); err != nil { + return nil, err + } + session := &oauthSession{ + Provider: provider, + Type: firstNonEmpty(asString(raw["type"]), provider), + } + if tokenMap := mapFromAny(raw["token"]); len(tokenMap) > 0 { + session.Token = tokenMap + session.AccessToken = firstNonEmpty(session.AccessToken, asString(tokenMap["access_token"])) + session.RefreshToken = firstNonEmpty(session.RefreshToken, asString(tokenMap["refresh_token"])) + session.IDToken = firstNonEmpty(session.IDToken, asString(tokenMap["id_token"])) + session.TokenType = firstNonEmpty(session.TokenType, asString(tokenMap["token_type"])) + session.Expire = firstNonEmpty(session.Expire, asString(tokenMap["expiry"]), asString(tokenMap["expiry_date"])) + if session.Token == nil { + session.Token = map[string]any{} + } + } + if authMap := mapFromAny(raw["auth"]); len(authMap) > 0 { + session.AccessToken = firstNonEmpty(session.AccessToken, asString(authMap["access_token"])) + session.RefreshToken = firstNonEmpty(session.RefreshToken, asString(authMap["refresh_token"])) + session.IDToken = firstNonEmpty(session.IDToken, asString(authMap["id_token"])) + session.TokenType = firstNonEmpty(session.TokenType, asString(authMap["token_type"])) + session.Expire = firstNonEmpty(session.Expire, asString(authMap["expiry"]), asString(authMap["expired"])) + } + session.AccessToken = firstNonEmpty(session.AccessToken, asString(raw["access_token"])) + session.RefreshToken = firstNonEmpty(session.RefreshToken, asString(raw["refresh_token"])) + session.IDToken = firstNonEmpty(session.IDToken, asString(raw["id_token"])) + session.TokenType = firstNonEmpty(session.TokenType, asString(raw["token_type"])) + session.AccountID = firstNonEmpty(asString(raw["account_id"]), asString(raw["sub"])) + session.Email = firstNonEmpty( + asString(raw["email"]), + asString(raw["account_email"]), + asString(raw["alias"]), + asString(raw["account_label"]), + asString(raw["label"]), + ) + session.Expire = firstNonEmpty(session.Expire, asString(raw["expire"]), asString(raw["expired"]), asString(raw["expiry_date"]), asString(raw["expiry"])) + session.LastRefresh = firstNonEmpty(asString(raw["last_refresh"]), time.Now().Format(time.RFC3339)) + session.ProjectID = firstNonEmpty(asString(raw["project_id"]), asString(raw["projectId"])) + session.DeviceID = firstNonEmpty(asString(raw["device_id"]), asString(raw["deviceId"])) + session.ResourceURL = asString(raw["resource_url"]) + session.Scope = firstNonEmpty(asString(raw["scope"]), asString(raw["scopes"])) + if models := stringSliceFromAny(raw["models"]); len(models) > 0 { + session.Models = models + } + if session.Token != nil { + session.Email = firstNonEmpty( + session.Email, + asString(session.Token["email"]), + asString(session.Token["alias"]), + asString(session.Token["account_label"]), + asString(session.Token["label"]), + ) + session.ProjectID = firstNonEmpty(session.ProjectID, asString(session.Token["project_id"]), asString(session.Token["projectId"])) + session.DeviceID = firstNonEmpty(session.DeviceID, asString(session.Token["device_id"]), asString(session.Token["deviceId"])) + session.Scope = firstNonEmpty(session.Scope, asString(session.Token["scope"]), asString(session.Token["scopes"])) + } + if claims := parseJWTClaims(session.IDToken); len(claims) > 0 { + session.Email = firstNonEmpty(session.Email, strings.TrimSpace(fmt.Sprintf("%v", claims["email"]))) + session.AccountID = firstNonEmpty( + session.AccountID, + strings.TrimSpace(fmt.Sprintf("%v", claims["https://api.openai.com/auth"])), + strings.TrimSpace(fmt.Sprintf("%v", claims["account_id"])), + strings.TrimSpace(fmt.Sprintf("%v", claims["sub"])), + ) + } + switch provider { + case defaultClaudeOAuthProvider: + if session.Expire == "" { + session.Expire = firstNonEmpty(asString(raw["expired"]), asString(raw["expire"])) + } + case defaultGeminiOAuthProvider: + if session.Token == nil { + if tokenMap := mapFromAny(raw["token"]); len(tokenMap) > 0 { + session.Token = tokenMap + } + } + if session.Token == nil { + session.Token = map[string]any{} + } + if session.Token["token_uri"] == nil { + session.Token["token_uri"] = defaultGeminiTokenURL + } + if session.Token["client_id"] == nil { + if clientID := strings.TrimSpace(defaultGeminiClientID); clientID != "" { + session.Token["client_id"] = clientID + } + } + if session.Token["client_secret"] == nil { + if clientSecret := strings.TrimSpace(defaultGeminiClientSecret); clientSecret != "" { + session.Token["client_secret"] = clientSecret + } + } + if session.Token["scopes"] == nil { + session.Token["scopes"] = append([]string(nil), defaultGoogleScopes...) + } + case defaultQwenOAuthProvider: + if session.Expire == "" { + session.Expire = firstNonEmpty(asString(raw["expired"]), asString(raw["expire"])) + } + case defaultKimiOAuthProvider: + if session.Expire == "" { + session.Expire = firstNonEmpty(asString(raw["expired"]), asString(raw["expire"])) + } + } + if session.AccessToken == "" && session.RefreshToken == "" { + return nil, fmt.Errorf("auth.json missing access_token/refresh_token in %s", fileName) + } + return session, nil +} + +func asString(value any) string { + switch v := value.(type) { + case string: + return strings.TrimSpace(v) + case json.Number: + return v.String() + case fmt.Stringer: + return strings.TrimSpace(v.String()) + case float64: + return strconv.FormatFloat(v, 'f', -1, 64) + case int: + return strconv.Itoa(v) + case int64: + return strconv.FormatInt(v, 10) + case bool: + if v { + return "true" + } + return "false" + default: + return "" + } +} + +func mapFromAny(value any) map[string]any { + raw, ok := value.(map[string]any) + if !ok || len(raw) == 0 { + return nil + } + out := make(map[string]any, len(raw)) + for k, v := range raw { + out[k] = v + } + return out +} + +func stringSliceFromAny(value any) []string { + items, ok := value.([]any) + if !ok { + if typed, ok := value.([]string); ok { + return append([]string(nil), typed...) + } + return nil + } + out := make([]string, 0, len(items)) + for _, item := range items { + if s := asString(item); s != "" { + out = append(out, s) + } + } + return uniqueStrings(out) +} + +func cloneStringAnyMap(in map[string]any) map[string]any { + if len(in) == 0 { + return nil + } + out := make(map[string]any, len(in)) + for k, v := range in { + out[k] = v + } + return out +} + +func defaultInt(value, fallback int) int { + if value > 0 { + return value + } + return fallback +} + +func (m *oauthManager) doFormTokenRequest(ctx context.Context, form url.Values) (map[string]any, error) { + return m.doFormTokenRequestURL(ctx, m.cfg.TokenURL, form) +} + +func (m *oauthManager) doFormTokenRequestURL(ctx context.Context, endpoint string, form url.Values) (map[string]any, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, strings.NewReader(form.Encode())) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.Header.Set("Accept", "application/json") + return m.doJSONRequest(req, "oauth token request") +} + +func (m *oauthManager) doJSONTokenRequest(ctx context.Context, payload map[string]any) (map[string]any, error) { + body, err := json.Marshal(payload) + if err != nil { + return nil, err + } + req, err := http.NewRequestWithContext(ctx, http.MethodPost, m.cfg.TokenURL, strings.NewReader(string(body))) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json") + return m.doJSONRequest(req, "oauth token request") +} + +func (m *oauthManager) doJSONRequest(req *http.Request, label string) (map[string]any, error) { + resp, err := m.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("%s failed: %w", label, err) + } + defer resp.Body.Close() + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("%s read failed: %w", label, err) + } + if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices { + return nil, fmt.Errorf("%s failed: status=%d body=%s", label, resp.StatusCode, strings.TrimSpace(string(body))) + } + var raw map[string]any + if err := json.Unmarshal(body, &raw); err != nil { + return nil, fmt.Errorf("%s decode failed: %w", label, err) + } + return raw, nil +} + +func (m *oauthManager) fetchUserEmail(ctx context.Context, token string) (string, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, m.cfg.UserInfoURL, nil) + if err != nil { + return "", err + } + req.Header.Set("Authorization", "Bearer "+strings.TrimSpace(token)) + raw, err := m.doJSONRequest(req, "oauth userinfo request") + if err != nil { + return "", err + } + email := asString(raw["email"]) + if email == "" { + return "", fmt.Errorf("oauth userinfo missing email") + } + return email, nil +} + +func (m *oauthManager) fetchAntigravityProjectID(ctx context.Context, token string) (string, error) { + endpointURL := fmt.Sprintf("%s/%s:loadCodeAssist", defaultAntigravityAPIEndpoint, defaultAntigravityAPIVersion) + body := `{"metadata":{"ideType":"ANTIGRAVITY","platform":"PLATFORM_UNSPECIFIED","pluginType":"GEMINI"}}` + req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpointURL, strings.NewReader(body)) + if err != nil { + return "", err + } + req.Header.Set("Authorization", "Bearer "+strings.TrimSpace(token)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("User-Agent", defaultAntigravityAPIUserAgent) + req.Header.Set("X-Goog-Api-Client", defaultAntigravityAPIClient) + req.Header.Set("Client-Metadata", defaultAntigravityClientMeta) + resp, err := m.httpClient.Do(req) + if err != nil { + return "", err + } + defer resp.Body.Close() + raw, err := io.ReadAll(resp.Body) + if err != nil { + return "", err + } + if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices { + return "", fmt.Errorf("antigravity project lookup failed: status=%d body=%s", resp.StatusCode, strings.TrimSpace(string(raw))) + } + var payload map[string]any + if err := json.Unmarshal(raw, &payload); err != nil { + return "", err + } + projectID := strings.TrimSpace(asString(payload["cloudaicompanionProject"])) + if projectID == "" { + if projectMap := mapFromAny(payload["cloudaicompanionProject"]); len(projectMap) > 0 { + projectID = strings.TrimSpace(asString(projectMap["id"])) + } + } + if projectID == "" { + return "", fmt.Errorf("antigravity project lookup missing project id") + } + return projectID, nil +} + +func (m *oauthManager) startDeviceFlow(ctx context.Context) (*OAuthPendingFlow, error) { + if m.cfg.FlowKind != oauthFlowDevice { + return nil, fmt.Errorf("oauth provider %s does not use device flow", m.cfg.Provider) + } + form := url.Values{} + form.Set("client_id", m.cfg.ClientID) + switch m.cfg.Provider { + case defaultQwenOAuthProvider: + verifier, challenge, err := generatePKCE() + if err != nil { + return nil, err + } + if len(m.cfg.Scopes) > 0 { + form.Set("scope", strings.Join(m.cfg.Scopes, " ")) + } + form.Set("code_challenge", challenge) + form.Set("code_challenge_method", "S256") + raw, err := m.doFormDeviceRequest(ctx, m.cfg.DeviceCodeURL, form) + if err != nil { + return nil, err + } + device, err := parseDeviceFlowPayload(raw) + if err != nil { + return nil, err + } + return &OAuthPendingFlow{ + Mode: oauthFlowDevice, + AuthURL: firstNonEmpty(device.VerificationURIComplete, device.VerificationURI), + UserCode: device.UserCode, + DeviceCode: device.DeviceCode, + PKCEVerifier: verifier, + IntervalSec: defaultInt(device.Interval, 5), + ExpiresAt: deviceExpiry(device.ExpiresIn), + Instructions: "Open the verification URL, finish authorization, then click continue to let the gateway poll for tokens.", + }, nil + case defaultKimiOAuthProvider: + raw, err := m.doFormDeviceRequest(ctx, m.cfg.DeviceCodeURL, form) + if err != nil { + return nil, err + } + device, err := parseDeviceFlowPayload(raw) + if err != nil { + return nil, err + } + return &OAuthPendingFlow{ + Mode: oauthFlowDevice, + AuthURL: firstNonEmpty(device.VerificationURIComplete, device.VerificationURI), + UserCode: device.UserCode, + DeviceCode: device.DeviceCode, + IntervalSec: defaultInt(device.Interval, 5), + ExpiresAt: deviceExpiry(device.ExpiresIn), + Instructions: "Open the verification URL, finish authorization, then click continue to let the gateway poll for tokens.", + }, nil + default: + return nil, fmt.Errorf("oauth device flow not implemented for provider %s", m.cfg.Provider) + } +} + +func (m *oauthManager) doFormDeviceRequest(ctx context.Context, endpoint string, form url.Values) (map[string]any, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, strings.NewReader(form.Encode())) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.Header.Set("Accept", "application/json") + if m.cfg.Provider == defaultKimiOAuthProvider { + req.Header.Set("X-Msh-Platform", "clawgo") + req.Header.Set("X-Msh-Version", "1.0.0") + req.Header.Set("X-Msh-Device-Name", "clawgo") + req.Header.Set("X-Msh-Device-Model", runtime.GOOS+" "+runtime.GOARCH) + req.Header.Set("X-Msh-Device-Id", randomDeviceID()) + } + return m.doJSONRequest(req, "oauth device request") +} + +func parseDeviceFlowPayload(raw map[string]any) (*oauthDeviceCodeResponse, error) { + body, err := json.Marshal(raw) + if err != nil { + return nil, err + } + var device oauthDeviceCodeResponse + if err := json.Unmarshal(body, &device); err != nil { + return nil, err + } + if strings.TrimSpace(device.DeviceCode) == "" { + return nil, fmt.Errorf("oauth device flow missing device_code") + } + return &device, nil +} + +func deviceExpiry(expiresIn int) string { + if expiresIn <= 0 { + return "" + } + return time.Now().Add(time.Duration(expiresIn) * time.Second).Format(time.RFC3339) +} + +func randomDeviceID() string { + token, err := randomURLToken(24) + if err != nil { + return "clawgo-device" + } + return token +} + +func (m *oauthManager) completeDeviceFlow(ctx context.Context, apiBase string, flow *OAuthPendingFlow, opts OAuthLoginOptions) (*oauthSession, []string, error) { + session, err := m.pollDeviceToken(ctx, flow) + if err != nil { + return nil, nil, err + } + if err := m.applyAccountLabel(session, opts); err != nil { + return nil, nil, err + } + models, _ := fetchOpenAIModels(ctx, m.httpClient, apiBase, session.AccessToken) + if len(models) > 0 { + session.Models = append([]string(nil), models...) + } + m.mu.Lock() + defer m.mu.Unlock() + path, err := m.allocateCredentialPathLocked(session) + if err != nil { + return nil, nil, err + } + session.FilePath = path + if err := m.persistSessionLocked(session); err != nil { + return nil, nil, err + } + m.cached = appendLoadedSession(m.cached, session) + m.cfg.CredentialFiles = uniqueStrings(append(m.cfg.CredentialFiles, path)) + m.cfg.CredentialFile = m.cfg.CredentialFiles[0] + return session, models, nil +} + +func (m *oauthManager) pollDeviceToken(ctx context.Context, flow *OAuthPendingFlow) (*oauthSession, error) { + if flow == nil || strings.TrimSpace(flow.DeviceCode) == "" { + return nil, fmt.Errorf("oauth device flow missing device code") + } + interval := time.Duration(defaultInt(flow.IntervalSec, 5)) * time.Second + deadline := time.Now().Add(m.cfg.DevicePollMax) + if expireAt, err := time.Parse(time.RFC3339, strings.TrimSpace(flow.ExpiresAt)); err == nil && expireAt.Before(deadline) { + deadline = expireAt + } + for { + if time.Now().After(deadline) { + return nil, fmt.Errorf("oauth device flow timed out") + } + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-time.After(interval): + } + form := url.Values{} + form.Set("grant_type", m.cfg.DeviceGrantType) + form.Set("client_id", m.cfg.ClientID) + form.Set("device_code", flow.DeviceCode) + if flow.PKCEVerifier != "" { + form.Set("code_verifier", flow.PKCEVerifier) + } + raw, err := m.doFormTokenRequest(ctx, form) + if err == nil { + session, convErr := sessionFromTokenPayload(m.cfg.Provider, raw) + if convErr != nil { + return nil, convErr + } + return m.enrichSession(ctx, session) + } + errText := strings.ToLower(err.Error()) + switch { + case strings.Contains(errText, "authorization_pending"): + continue + case strings.Contains(errText, "slow_down"): + if interval < 10*time.Second { + interval = interval + time.Second + } + continue + case strings.Contains(errText, "access_denied"): + return nil, fmt.Errorf("oauth device flow denied by user") + case strings.Contains(errText, "expired_token"): + return nil, fmt.Errorf("oauth device code expired") + default: + return nil, err + } + } +} diff --git a/pkg/providers/oauth_test.go b/pkg/providers/oauth_test.go new file mode 100644 index 0000000..532e15e --- /dev/null +++ b/pkg/providers/oauth_test.go @@ -0,0 +1,1842 @@ +package providers + +import ( + "bytes" + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "strings" + "sync/atomic" + "testing" + "time" + + "github.com/YspCoder/clawgo/pkg/config" +) + +func TestHTTPProviderOAuthRefreshesExpiredSession(t *testing.T) { + t.Parallel() + + var refreshCalls int32 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/oauth/token": + atomic.AddInt32(&refreshCalls, 1) + if err := r.ParseForm(); err != nil { + t.Fatalf("parse token form failed: %v", err) + } + if got := r.Form.Get("grant_type"); got != "refresh_token" { + t.Fatalf("unexpected grant_type: %s", got) + } + if got := r.Form.Get("refresh_token"); got != "refresh-token" { + t.Fatalf("unexpected refresh_token: %s", got) + } + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"access_token":"fresh-token","refresh_token":"refresh-token","expires_in":3600}`)) + case "/v1/responses": + if got := r.Header.Get("Authorization"); got != "Bearer fresh-token" { + t.Fatalf("unexpected authorization header: %s", got) + } + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"status":"completed","output_text":"ok"}`)) + default: + http.NotFound(w, r) + } + })) + defer server.Close() + + credFile := filepath.Join(t.TempDir(), "codex.json") + initial := oauthSession{ + Provider: "codex", + AccessToken: "expired-token", + RefreshToken: "refresh-token", + Expire: time.Now().Add(-time.Hour).Format(time.RFC3339), + } + raw, err := json.Marshal(initial) + if err != nil { + t.Fatalf("marshal session failed: %v", err) + } + if err := os.WriteFile(credFile, raw, 0o600); err != nil { + t.Fatalf("write credential file failed: %v", err) + } + + pc := config.ProviderConfig{ + APIBase: server.URL + "/v1", + Auth: "oauth", + TimeoutSec: 5, + OAuth: config.ProviderOAuthConfig{ + Provider: "codex", + CredentialFile: credFile, + ClientID: "test-client", + TokenURL: server.URL + "/oauth/token", + AuthURL: server.URL + "/oauth/authorize", + }, + } + oauth, err := newOAuthManager(pc, 5*time.Second) + if err != nil { + t.Fatalf("new oauth manager failed: %v", err) + } + provider := NewHTTPProvider("test-oauth", "", pc.APIBase, "gpt-test", false, "oauth", 5*time.Second, oauth) + + resp, err := provider.Chat(context.Background(), []Message{{Role: "user", Content: "hello"}}, nil, "gpt-test", nil) + if err != nil { + t.Fatalf("chat failed: %v", err) + } + if resp.Content != "ok" { + t.Fatalf("unexpected chat content: %q", resp.Content) + } + if got := atomic.LoadInt32(&refreshCalls); got != 1 { + t.Fatalf("expected exactly one refresh call, got %d", got) + } + + savedRaw, err := os.ReadFile(credFile) + if err != nil { + t.Fatalf("read refreshed credential file failed: %v", err) + } + if !strings.Contains(string(savedRaw), "fresh-token") { + t.Fatalf("expected refreshed token to be persisted, got %s", string(savedRaw)) + } +} + +func TestOAuthLoginManualCallbackURLParse(t *testing.T) { + t.Parallel() + + result, err := waitForOAuthCodeManual( + "https://example.com/auth?state=test-state", + bytes.NewBufferString("http://localhost:1455/auth/callback?code=auth-code&state=test-state\n"), + ) + if err != nil { + t.Fatalf("manual callback parse failed: %v", err) + } + if result.Code != "auth-code" { + t.Fatalf("unexpected auth code: %s", result.Code) + } + if result.State != "test-state" { + t.Fatalf("unexpected state: %s", result.State) + } +} + +func TestHTTPProviderOAuthSwitchesAccountOnQuota(t *testing.T) { + t.Parallel() + + dir := t.TempDir() + firstFile := filepath.Join(dir, "first.json") + secondFile := filepath.Join(dir, "second.json") + writeSession := func(path, token, email string) { + t.Helper() + raw, err := json.Marshal(oauthSession{ + Provider: "codex", + AccessToken: token, + Email: email, + Expire: time.Now().Add(time.Hour).Format(time.RFC3339), + }) + if err != nil { + t.Fatalf("marshal session failed: %v", err) + } + if err := os.WriteFile(path, raw, 0o600); err != nil { + t.Fatalf("write session failed: %v", err) + } + } + writeSession(firstFile, "token-a", "a@example.com") + writeSession(secondFile, "token-b", "b@example.com") + + var tokenAUsed int32 + var tokenBUsed int32 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/v1/responses" { + http.NotFound(w, r) + return + } + switch r.Header.Get("Authorization") { + case "Bearer token-a": + atomic.AddInt32(&tokenAUsed, 1) + w.WriteHeader(http.StatusTooManyRequests) + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"error":{"code":"insufficient_quota","message":"quota exceeded"}}`)) + case "Bearer token-b": + atomic.AddInt32(&tokenBUsed, 1) + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"status":"completed","output_text":"ok-from-second"}`)) + default: + t.Fatalf("unexpected auth header: %s", r.Header.Get("Authorization")) + } + })) + defer server.Close() + + pc := config.ProviderConfig{ + APIBase: server.URL + "/v1", + Auth: "oauth", + TimeoutSec: 5, + OAuth: config.ProviderOAuthConfig{ + Provider: "codex", + CredentialFile: firstFile, + CredentialFiles: []string{firstFile, secondFile}, + ClientID: "test-client", + TokenURL: server.URL + "/oauth/token", + AuthURL: server.URL + "/oauth/authorize", + }, + } + oauth, err := newOAuthManager(pc, 5*time.Second) + if err != nil { + t.Fatalf("new oauth manager failed: %v", err) + } + provider := NewHTTPProvider("test-oauth", "", pc.APIBase, "gpt-test", false, "oauth", 5*time.Second, oauth) + resp, err := provider.Chat(context.Background(), []Message{{Role: "user", Content: "hello"}}, nil, "gpt-test", nil) + if err != nil { + t.Fatalf("chat failed: %v", err) + } + if resp.Content != "ok-from-second" { + t.Fatalf("unexpected response content: %q", resp.Content) + } + if atomic.LoadInt32(&tokenAUsed) != 1 || atomic.LoadInt32(&tokenBUsed) != 1 { + t.Fatalf("expected one attempt per token, got token-a=%d token-b=%d", tokenAUsed, tokenBUsed) + } +} + +func TestOAuthManagerPreRefreshesExpiringSession(t *testing.T) { + t.Parallel() + + var refreshCalls int32 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/oauth/token" { + http.NotFound(w, r) + return + } + atomic.AddInt32(&refreshCalls, 1) + if err := r.ParseForm(); err != nil { + t.Fatalf("parse token form failed: %v", err) + } + if got := r.Form.Get("grant_type"); got != "refresh_token" { + t.Fatalf("unexpected grant_type: %s", got) + } + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"access_token":"prefreshed-token","refresh_token":"refresh-token","expires_in":3600}`)) + })) + defer server.Close() + + credFile := filepath.Join(t.TempDir(), "prefresh.json") + raw, err := json.Marshal(oauthSession{ + Provider: "codex", + AccessToken: "old-token", + RefreshToken: "refresh-token", + Expire: time.Now().Add(2 * time.Minute).Format(time.RFC3339), + }) + if err != nil { + t.Fatalf("marshal session failed: %v", err) + } + if err := os.WriteFile(credFile, raw, 0o600); err != nil { + t.Fatalf("write session failed: %v", err) + } + + originalAPI := defaultAntigravityAPIEndpoint + originalUserInfo := defaultAntigravityUserInfoURL + defaultAntigravityAPIEndpoint = server.URL + defaultAntigravityUserInfoURL = server.URL + "/userinfo" + t.Cleanup(func() { + defaultAntigravityAPIEndpoint = originalAPI + defaultAntigravityUserInfoURL = originalUserInfo + }) + + manager, err := newOAuthManager(config.ProviderConfig{ + Auth: "oauth", + TimeoutSec: 5, + OAuth: config.ProviderOAuthConfig{ + Provider: "codex", + CredentialFile: credFile, + ClientID: "test-client", + TokenURL: server.URL + "/oauth/token", + AuthURL: server.URL + "/oauth/authorize", + }, + }, 5*time.Second) + if err != nil { + t.Fatalf("new oauth manager failed: %v", err) + } + defer manager.bgCancel() + + result, err := manager.refreshExpiringSessions(context.Background(), 10*time.Minute) + if err != nil { + t.Fatalf("pre-refresh failed: %v", err) + } + if result == nil || result.Refreshed != 1 { + t.Fatalf("expected one refreshed account, got %#v", result) + } + if atomic.LoadInt32(&refreshCalls) != 1 { + t.Fatalf("expected one refresh call, got %d", refreshCalls) + } + saved, err := os.ReadFile(credFile) + if err != nil { + t.Fatalf("read saved session failed: %v", err) + } + if !strings.Contains(string(saved), "prefreshed-token") { + t.Fatalf("expected prefreshed token in file, got %s", string(saved)) + } +} + +func TestResolveOAuthConfigSupportsAdditionalProviders(t *testing.T) { + t.Parallel() + + cases := []struct { + name string + provider string + want string + flow string + }{ + {name: "anthropic-alias", provider: "anthropic", want: "claude", flow: oauthFlowCallback}, + {name: "antigravity", provider: "antigravity", want: "antigravity", flow: oauthFlowCallback}, + {name: "gemini", provider: "gemini", want: "gemini", flow: oauthFlowCallback}, + {name: "kimi", provider: "kimi", want: "kimi", flow: oauthFlowDevice}, + {name: "qwen", provider: "qwen", want: "qwen", flow: oauthFlowDevice}, + } + + for _, tc := range cases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + cfg, err := resolveOAuthConfig(config.ProviderConfig{ + Auth: "oauth", + OAuth: config.ProviderOAuthConfig{ + Provider: tc.provider, + }, + }) + if err != nil { + t.Fatalf("resolve oauth config failed: %v", err) + } + if cfg.Provider != tc.want { + t.Fatalf("unexpected provider: %s", cfg.Provider) + } + if cfg.FlowKind != tc.flow { + t.Fatalf("unexpected flow kind: %s", cfg.FlowKind) + } + }) + } +} + +func TestNewOAuthManagerUsesAnthropicTransportForClaude(t *testing.T) { + t.Parallel() + + manager, err := newOAuthManager(config.ProviderConfig{ + Auth: "oauth", + OAuth: config.ProviderOAuthConfig{ + Provider: "anthropic", + }, + }, 5*time.Second) + if err != nil { + t.Fatalf("new oauth manager failed: %v", err) + } + defer manager.bgCancel() + + if _, ok := manager.httpClient.Transport.(*anthropicOAuthRoundTripper); !ok { + t.Fatalf("expected anthropic oauth transport, got %T", manager.httpClient.Transport) + } +} + +func TestNewOAuthManagerUsesDefaultTransportForNonClaude(t *testing.T) { + t.Parallel() + + manager, err := newOAuthManager(config.ProviderConfig{ + Auth: "oauth", + OAuth: config.ProviderOAuthConfig{ + Provider: "codex", + }, + }, 5*time.Second) + if err != nil { + t.Fatalf("new oauth manager failed: %v", err) + } + defer manager.bgCancel() + + if manager.httpClient.Transport != nil { + t.Fatalf("expected default transport for non-claude provider, got %T", manager.httpClient.Transport) + } +} + +func TestResolveOAuthConfigAppliesProviderRefreshLeadDefaults(t *testing.T) { + t.Parallel() + + cases := []struct { + provider string + want time.Duration + }{ + {provider: "codex", want: 5 * 24 * time.Hour}, + {provider: "anthropic", want: 4 * time.Hour}, + {provider: "antigravity", want: 5 * time.Minute}, + {provider: "gemini", want: 30 * time.Minute}, + {provider: "kimi", want: 5 * time.Minute}, + {provider: "qwen", want: 3 * time.Hour}, + } + + for _, tc := range cases { + tc := tc + t.Run(tc.provider, func(t *testing.T) { + t.Parallel() + cfg, err := resolveOAuthConfig(config.ProviderConfig{ + Auth: "oauth", + OAuth: config.ProviderOAuthConfig{ + Provider: tc.provider, + }, + }) + if err != nil { + t.Fatalf("resolve oauth config failed: %v", err) + } + if cfg.RefreshLead != tc.want { + t.Fatalf("unexpected refresh lead for %s: got %v want %v", tc.provider, cfg.RefreshLead, tc.want) + } + }) + } +} + +func TestOAuthImportGeminiNestedTokenRefreshesWithTokenMetadata(t *testing.T) { + t.Parallel() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/oauth/token": + if err := r.ParseForm(); err != nil { + t.Fatalf("parse form failed: %v", err) + } + if got := r.Form.Get("client_secret"); got != "secret-1" { + t.Fatalf("unexpected client_secret: %s", got) + } + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"access_token":"gemini-fresh","refresh_token":"gemini-refresh","expires_in":3600}`)) + case "/userinfo": + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"email":"gemini@example.com"}`)) + default: + http.NotFound(w, r) + } + })) + defer server.Close() + + manager, err := newOAuthManager(config.ProviderConfig{ + Auth: "oauth", + OAuth: config.ProviderOAuthConfig{ + Provider: "gemini", + TokenURL: server.URL + "/oauth/token", + ClientID: "client-1", + ClientSecret: "secret-1", + }, + }, 5*time.Second) + if err != nil { + t.Fatalf("new oauth manager failed: %v", err) + } + manager.cfg.UserInfoURL = server.URL + "/userinfo" + + raw := []byte(`{ + "type": "gemini", + "email": "gemini@example.com", + "project_id": "demo-project", + "token": { + "refresh_token": "gemini-refresh", + "client_id": "client-1", + "client_secret": "secret-1", + "token_uri": "` + server.URL + `/oauth/token" + } + }`) + session, err := parseImportedOAuthSession("gemini", "gemini.json", raw) + if err != nil { + t.Fatalf("parse imported oauth session failed: %v", err) + } + refreshed, err := manager.refreshImportedSession(context.Background(), session) + if err != nil { + t.Fatalf("refresh imported session failed: %v", err) + } + if refreshed.AccessToken != "gemini-fresh" { + t.Fatalf("unexpected access token: %s", refreshed.AccessToken) + } + if refreshed.Email != "gemini@example.com" { + t.Fatalf("unexpected email: %s", refreshed.Email) + } + if refreshed.ProjectID != "demo-project" { + t.Fatalf("unexpected project id: %s", refreshed.ProjectID) + } +} + +func TestAntigravityEnrichSessionAddsEmailAndProjectID(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/userinfo": + if got := r.Header.Get("Authorization"); got != "Bearer antigravity-token" { + t.Fatalf("unexpected userinfo authorization: %s", got) + } + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"email":"antigravity@example.com"}`)) + case "/v1internal:loadCodeAssist": + if got := r.Header.Get("Authorization"); got != "Bearer antigravity-token" { + t.Fatalf("unexpected project authorization: %s", got) + } + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"cloudaicompanionProject":"project-123"}`)) + default: + http.NotFound(w, r) + } + })) + defer server.Close() + + originalAPI := defaultAntigravityAPIEndpoint + originalUserInfo := defaultAntigravityUserInfoURL + defaultAntigravityAPIEndpoint = server.URL + defaultAntigravityUserInfoURL = server.URL + "/userinfo" + t.Cleanup(func() { + defaultAntigravityAPIEndpoint = originalAPI + defaultAntigravityUserInfoURL = originalUserInfo + }) + + manager, err := newOAuthManager(config.ProviderConfig{ + Auth: "oauth", + TimeoutSec: 5, + OAuth: config.ProviderOAuthConfig{ + Provider: "antigravity", + ClientID: "client-id", + ClientSecret: "client-secret", + AuthURL: server.URL + "/oauth/authorize", + RedirectURL: "http://localhost:51121/oauth-callback", + RefreshLeadSec: 300, + }, + }, 5*time.Second) + if err != nil { + t.Fatalf("new oauth manager failed: %v", err) + } + defer manager.bgCancel() + + session, err := manager.enrichSession(context.Background(), &oauthSession{ + Provider: "antigravity", + AccessToken: "antigravity-token", + }) + if err != nil { + t.Fatalf("enrich session failed: %v", err) + } + if session.Email != "antigravity@example.com" { + t.Fatalf("unexpected email: %#v", session) + } + if session.ProjectID != "project-123" { + t.Fatalf("expected project id enrichment, got %#v", session) + } +} + +func TestQwenDeviceFlowRequiresAccountLabelWhenEmailMissing(t *testing.T) { + t.Parallel() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/device": + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"device_code":"dev-1","user_code":"user-1","verification_uri_complete":"https://chat.qwen.ai/device?code=user-1","interval":1,"expires_in":60}`)) + case "/token": + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"access_token":"qwen-token","refresh_token":"refresh-token","expires_in":3600}`)) + default: + http.NotFound(w, r) + } + })) + defer server.Close() + + manager, err := newOAuthManager(config.ProviderConfig{ + Auth: "oauth", + TimeoutSec: 5, + OAuth: config.ProviderOAuthConfig{ + Provider: "qwen", + AuthURL: server.URL + "/device", + TokenURL: server.URL + "/token", + CredentialFile: filepath.Join(t.TempDir(), "qwen.json"), + }, + }, 5*time.Second) + if err != nil { + t.Fatalf("new oauth manager failed: %v", err) + } + defer manager.bgCancel() + + flow, err := manager.startDeviceFlow(context.Background()) + if err != nil { + t.Fatalf("start device flow failed: %v", err) + } + _, _, err = manager.completeDeviceFlow(context.Background(), "", flow, OAuthLoginOptions{}) + if err == nil || !strings.Contains(err.Error(), "account_label") { + t.Fatalf("expected qwen account_label error, got %v", err) + } + + session, _, err := manager.completeDeviceFlow(context.Background(), "", flow, OAuthLoginOptions{AccountLabel: "qwen-alias"}) + if err != nil { + t.Fatalf("complete device flow with label failed: %v", err) + } + if session.Email != "qwen-alias" { + t.Fatalf("expected qwen alias persisted as email label, got %#v", session) + } +} + +func TestImportSessionEnrichesAntigravityMetadata(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/userinfo": + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"email":"imported@example.com"}`)) + case "/v1internal:loadCodeAssist": + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"cloudaicompanionProject":"import-project"}`)) + case "/v1/models": + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"data":[{"id":"g-model"}]}`)) + default: + http.NotFound(w, r) + } + })) + defer server.Close() + + originalAPI := defaultAntigravityAPIEndpoint + originalUserInfo := defaultAntigravityUserInfoURL + defaultAntigravityAPIEndpoint = server.URL + defaultAntigravityUserInfoURL = server.URL + "/userinfo" + t.Cleanup(func() { + defaultAntigravityAPIEndpoint = originalAPI + defaultAntigravityUserInfoURL = originalUserInfo + }) + + manager, err := newOAuthManager(config.ProviderConfig{ + APIBase: server.URL + "/v1", + Auth: "oauth", + TimeoutSec: 5, + OAuth: config.ProviderOAuthConfig{ + Provider: "antigravity", + CredentialFile: filepath.Join(t.TempDir(), "antigravity.json"), + }, + }, 5*time.Second) + if err != nil { + t.Fatalf("new oauth manager failed: %v", err) + } + defer manager.bgCancel() + + raw := []byte(`{"access_token":"import-token","refresh_token":"refresh-token","expired":"2030-01-01T00:00:00Z"}`) + session, models, err := manager.importSession(context.Background(), server.URL+"/v1", "auth.json", raw, OAuthLoginOptions{}) + if err != nil { + t.Fatalf("import session failed: %v", err) + } + if session.Email != "imported@example.com" || session.ProjectID != "import-project" { + t.Fatalf("expected antigravity enrichment, got %#v", session) + } + if len(models) != 1 || models[0] != "g-model" { + t.Fatalf("unexpected models: %#v", models) + } +} + +func TestPersistSessionAddsGeminiTokenMetadata(t *testing.T) { + t.Parallel() + + credFile := filepath.Join(t.TempDir(), "gemini.json") + manager, err := newOAuthManager(config.ProviderConfig{ + Auth: "oauth", + TimeoutSec: 5, + OAuth: config.ProviderOAuthConfig{ + Provider: "gemini", + CredentialFile: credFile, + }, + }, 5*time.Second) + if err != nil { + t.Fatalf("new oauth manager failed: %v", err) + } + defer manager.bgCancel() + + manager.mu.Lock() + err = manager.persistSessionLocked(&oauthSession{ + Provider: "gemini", + AccessToken: "gem-access", + RefreshToken: "gem-refresh", + Expire: "2030-01-01T00:00:00Z", + Email: "gem@example.com", + FilePath: credFile, + }) + manager.mu.Unlock() + if err != nil { + t.Fatalf("persist session failed: %v", err) + } + raw, err := os.ReadFile(credFile) + if err != nil { + t.Fatalf("read credential file failed: %v", err) + } + var payload map[string]any + if err := json.Unmarshal(raw, &payload); err != nil { + t.Fatalf("unmarshal credential file failed: %v", err) + } + tokenMap, _ := payload["token"].(map[string]any) + if tokenMap == nil { + t.Fatalf("expected token map in persisted gemini session, got %s", string(raw)) + } + if tokenMap["token_uri"] != defaultGeminiTokenURL { + t.Fatalf("unexpected gemini token metadata: %#v", tokenMap) + } + if defaultGeminiClientID != "" && tokenMap["client_id"] != defaultGeminiClientID { + t.Fatalf("unexpected gemini client_id metadata: %#v", tokenMap) + } + if defaultGeminiClientSecret != "" && tokenMap["client_secret"] != defaultGeminiClientSecret { + t.Fatalf("unexpected gemini client_secret metadata: %#v", tokenMap) + } +} + +func TestParseImportedOAuthSessionSupportsAliasProjectAndDeviceVariants(t *testing.T) { + t.Parallel() + + session, err := parseImportedOAuthSession("qwen", "auth.json", []byte(`{ + "refresh_token": "rt-1", + "token": { + "access_token": "at-1", + "account_label": "alias-qwen", + "projectId": "proj-1", + "deviceId": "dev-1", + "scopes": "openid profile" + } + }`)) + if err != nil { + t.Fatalf("parse imported session failed: %v", err) + } + if session.Email != "alias-qwen" || session.ProjectID != "proj-1" || session.DeviceID != "dev-1" || session.Scope != "openid profile" { + t.Fatalf("unexpected parsed session: %#v", session) + } +} + +func TestOAuthDeviceFlowQwenManualCompletes(t *testing.T) { + t.Parallel() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/device": + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"device_code":"dev-1","user_code":"user-1","verification_uri_complete":"https://chat.qwen.ai/device?code=user-1","interval":1,"expires_in":60}`)) + case "/token": + if err := r.ParseForm(); err != nil { + t.Fatalf("parse form failed: %v", err) + } + if got := r.Form.Get("device_code"); got != "dev-1" { + t.Fatalf("unexpected device_code: %s", got) + } + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"access_token":"qwen-at","refresh_token":"qwen-rt","expires_in":3600}`)) + case "/models": + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"data":[{"id":"qwen-test"}]}`)) + default: + http.NotFound(w, r) + } + })) + defer server.Close() + + dir := t.TempDir() + manager, err := newOAuthManager(config.ProviderConfig{ + APIBase: server.URL, + Auth: "oauth", + OAuth: config.ProviderOAuthConfig{ + Provider: "qwen", + CredentialFile: filepath.Join(dir, "qwen.json"), + AuthURL: server.URL + "/device", + TokenURL: server.URL + "/token", + }, + }, 5*time.Second) + if err != nil { + t.Fatalf("new oauth manager failed: %v", err) + } + + flow, err := manager.startDeviceFlow(context.Background()) + if err != nil { + t.Fatalf("start device flow failed: %v", err) + } + if flow.Mode != oauthFlowDevice { + t.Fatalf("unexpected flow mode: %s", flow.Mode) + } + session, models, err := manager.completeDeviceFlow(context.Background(), server.URL, flow, OAuthLoginOptions{AccountLabel: "qwen-label"}) + if err != nil { + t.Fatalf("complete device flow failed: %v", err) + } + if session.AccessToken != "qwen-at" { + t.Fatalf("unexpected access token: %s", session.AccessToken) + } + if session.FilePath == "" { + t.Fatalf("expected credential file path") + } + if session.Email != "qwen-label" { + t.Fatalf("expected qwen label, got %#v", session) + } + if len(models) != 1 || models[0] != "qwen-test" { + t.Fatalf("unexpected models: %#v", models) + } +} + +func TestHTTPProviderHybridFallsBackFromAPIKeyToOAuth(t *testing.T) { + t.Parallel() + + dir := t.TempDir() + credFile := filepath.Join(dir, "oauth.json") + raw, err := json.Marshal(oauthSession{ + Provider: "codex", + AccessToken: "oauth-token", + Email: "oauth@example.com", + Expire: time.Now().Add(time.Hour).Format(time.RFC3339), + }) + if err != nil { + t.Fatalf("marshal session failed: %v", err) + } + if err := os.WriteFile(credFile, raw, 0o600); err != nil { + t.Fatalf("write session failed: %v", err) + } + + var apiKeyCalls int32 + var oauthCalls int32 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/v1/responses" { + http.NotFound(w, r) + return + } + switch r.Header.Get("Authorization") { + case "Bearer api-key-1": + atomic.AddInt32(&apiKeyCalls, 1) + w.WriteHeader(http.StatusTooManyRequests) + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"error":{"code":"insufficient_quota","message":"quota exceeded"}}`)) + case "Bearer oauth-token": + atomic.AddInt32(&oauthCalls, 1) + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"status":"completed","output_text":"ok-from-oauth"}`)) + default: + t.Fatalf("unexpected auth header: %s", r.Header.Get("Authorization")) + } + })) + defer server.Close() + + pc := config.ProviderConfig{ + APIBase: server.URL + "/v1", + APIKey: "api-key-1", + Auth: "hybrid", + TimeoutSec: 5, + OAuth: config.ProviderOAuthConfig{ + Provider: "codex", + CredentialFile: credFile, + TokenURL: server.URL + "/oauth/token", + AuthURL: server.URL + "/oauth/authorize", + }, + } + oauth, err := newOAuthManager(pc, 5*time.Second) + if err != nil { + t.Fatalf("new oauth manager failed: %v", err) + } + provider := NewHTTPProvider("test-hybrid", pc.APIKey, pc.APIBase, "gpt-test", false, "hybrid", 5*time.Second, oauth) + resp, err := provider.Chat(context.Background(), []Message{{Role: "user", Content: "hello"}}, nil, "gpt-test", nil) + if err != nil { + t.Fatalf("chat failed: %v", err) + } + if resp.Content != "ok-from-oauth" { + t.Fatalf("unexpected response content: %q", resp.Content) + } + if atomic.LoadInt32(&apiKeyCalls) != 1 || atomic.LoadInt32(&oauthCalls) != 1 { + t.Fatalf("expected one api-key and one oauth attempt, got api=%d oauth=%d", apiKeyCalls, oauthCalls) + } +} + +func TestHTTPProviderHybridOAuthFirstUsesOAuthBeforeAPIKey(t *testing.T) { + t.Parallel() + + dir := t.TempDir() + credFile := filepath.Join(dir, "oauth.json") + raw, err := json.Marshal(oauthSession{ + Provider: "codex", + AccessToken: "oauth-token", + Email: "oauth@example.com", + Expire: time.Now().Add(time.Hour).Format(time.RFC3339), + }) + if err != nil { + t.Fatalf("marshal session failed: %v", err) + } + if err := os.WriteFile(credFile, raw, 0o600); err != nil { + t.Fatalf("write session failed: %v", err) + } + + var apiKeyCalls int32 + var oauthCalls int32 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/v1/responses" { + http.NotFound(w, r) + return + } + switch r.Header.Get("Authorization") { + case "Bearer api-key-1": + atomic.AddInt32(&apiKeyCalls, 1) + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"status":"completed","output_text":"ok-from-api"}`)) + case "Bearer oauth-token": + atomic.AddInt32(&oauthCalls, 1) + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"status":"completed","output_text":"ok-from-oauth"}`)) + default: + t.Fatalf("unexpected auth header: %s", r.Header.Get("Authorization")) + } + })) + defer server.Close() + + pc := config.ProviderConfig{ + APIBase: server.URL + "/v1", + APIKey: "api-key-1", + Auth: "hybrid", + TimeoutSec: 5, + OAuth: config.ProviderOAuthConfig{ + Provider: "codex", + CredentialFile: credFile, + TokenURL: server.URL + "/oauth/token", + AuthURL: server.URL + "/oauth/authorize", + HybridPriority: "oauth_first", + }, + } + oauth, err := newOAuthManager(pc, 5*time.Second) + if err != nil { + t.Fatalf("new oauth manager failed: %v", err) + } + provider := NewHTTPProvider("test-hybrid", pc.APIKey, pc.APIBase, "gpt-test", false, "hybrid", 5*time.Second, oauth) + resp, err := provider.Chat(context.Background(), []Message{{Role: "user", Content: "hello"}}, nil, "gpt-test", nil) + if err != nil { + t.Fatalf("chat failed: %v", err) + } + if resp.Content != "ok-from-oauth" { + t.Fatalf("unexpected response content: %q", resp.Content) + } + if atomic.LoadInt32(&oauthCalls) != 1 || atomic.LoadInt32(&apiKeyCalls) != 0 { + t.Fatalf("expected oauth first only, got api=%d oauth=%d", apiKeyCalls, oauthCalls) + } +} + +func TestOAuthManagerCooldownSkipsExhaustedAccount(t *testing.T) { + t.Parallel() + + dir := t.TempDir() + firstFile := filepath.Join(dir, "first.json") + secondFile := filepath.Join(dir, "second.json") + writeSession := func(path, token string) { + t.Helper() + raw, err := json.Marshal(oauthSession{ + Provider: "codex", + AccessToken: token, + Expire: time.Now().Add(time.Hour).Format(time.RFC3339), + }) + if err != nil { + t.Fatalf("marshal session failed: %v", err) + } + if err := os.WriteFile(path, raw, 0o600); err != nil { + t.Fatalf("write session failed: %v", err) + } + } + writeSession(firstFile, "token-a") + writeSession(secondFile, "token-b") + + manager, err := newOAuthManager(config.ProviderConfig{ + Auth: "oauth", + TimeoutSec: 5, + OAuth: config.ProviderOAuthConfig{ + Provider: "codex", + CredentialFile: firstFile, + CredentialFiles: []string{firstFile, secondFile}, + CooldownSec: 3600, + }, + }, 5*time.Second) + if err != nil { + t.Fatalf("new oauth manager failed: %v", err) + } + + attempts, err := manager.prepareAttemptsLocked(context.Background()) + if err != nil { + t.Fatalf("prepare attempts failed: %v", err) + } + if len(attempts) != 2 { + t.Fatalf("expected 2 attempts, got %d", len(attempts)) + } + manager.markExhausted(attempts[0].Session, oauthFailureRateLimit) + nextAttempts, err := manager.prepareAttemptsLocked(context.Background()) + if err != nil { + t.Fatalf("prepare attempts after cooldown failed: %v", err) + } + if len(nextAttempts) != 1 { + t.Fatalf("expected 1 available attempt after cooldown, got %d", len(nextAttempts)) + } + if nextAttempts[0].Token != "token-b" { + t.Fatalf("unexpected token after cooldown: %s", nextAttempts[0].Token) + } + accounts, err := (&OAuthLoginManager{manager: manager}).ListAccounts() + if err != nil { + t.Fatalf("list accounts failed: %v", err) + } + foundCooldown := false + for _, account := range accounts { + if account.CredentialFile == firstFile && account.CooldownUntil != "" { + foundCooldown = true + } + } + if !foundCooldown { + t.Fatalf("expected cooldown metadata to be exposed in account list") + } +} + +func TestOAuthManagerPrefersHealthierAccount(t *testing.T) { + t.Parallel() + + dir := t.TempDir() + firstFile := filepath.Join(dir, "first.json") + secondFile := filepath.Join(dir, "second.json") + writeSession := func(path, token string) { + t.Helper() + raw, err := json.Marshal(oauthSession{ + Provider: "codex", + AccessToken: token, + Expire: time.Now().Add(time.Hour).Format(time.RFC3339), + }) + if err != nil { + t.Fatalf("marshal session failed: %v", err) + } + if err := os.WriteFile(path, raw, 0o600); err != nil { + t.Fatalf("write session failed: %v", err) + } + } + writeSession(firstFile, "token-a") + writeSession(secondFile, "token-b") + + manager, err := newOAuthManager(config.ProviderConfig{ + Auth: "oauth", + TimeoutSec: 5, + OAuth: config.ProviderOAuthConfig{ + Provider: "codex", + CredentialFile: firstFile, + CredentialFiles: []string{firstFile, secondFile}, + CooldownSec: 60, + }, + }, 5*time.Second) + if err != nil { + t.Fatalf("new oauth manager failed: %v", err) + } + + attempts, err := manager.prepareAttemptsLocked(context.Background()) + if err != nil { + t.Fatalf("prepare attempts failed: %v", err) + } + manager.markExhausted(attempts[0].Session, oauthFailureQuota) + delete(manager.cooldowns, attempts[0].Session.FilePath) + attempts, err = manager.prepareAttemptsLocked(context.Background()) + if err != nil { + t.Fatalf("prepare attempts after health drop failed: %v", err) + } + if len(attempts) != 2 { + t.Fatalf("expected 2 attempts, got %d", len(attempts)) + } + if attempts[0].Token != "token-b" { + t.Fatalf("expected healthier token-b first, got %s", attempts[0].Token) + } +} + +func TestClassifyOAuthFailureDifferentiatesReasons(t *testing.T) { + t.Parallel() + + reason, retry := classifyOAuthFailure(http.StatusTooManyRequests, []byte(`{"error":{"code":"insufficient_quota"}}`)) + if !retry || reason != oauthFailureQuota { + t.Fatalf("expected quota classification, got retry=%v reason=%s", retry, reason) + } + reason, retry = classifyOAuthFailure(http.StatusTooManyRequests, []byte(`{"error":{"message":"rate limit exceeded"}}`)) + if !retry || reason != oauthFailureRateLimit { + t.Fatalf("expected rate-limit classification, got retry=%v reason=%s", retry, reason) + } + reason, retry = classifyOAuthFailure(http.StatusForbidden, []byte(`{"error":"forbidden"}`)) + if !retry || reason != oauthFailureForbidden { + t.Fatalf("expected forbidden classification, got retry=%v reason=%s", retry, reason) + } +} + +func TestHTTPProviderHybridSkipsAPIKeyDuringCooldown(t *testing.T) { + t.Parallel() + + dir := t.TempDir() + credFile := filepath.Join(dir, "oauth.json") + raw, err := json.Marshal(oauthSession{ + Provider: "codex", + AccessToken: "oauth-token", + Expire: time.Now().Add(time.Hour).Format(time.RFC3339), + }) + if err != nil { + t.Fatalf("marshal session failed: %v", err) + } + if err := os.WriteFile(credFile, raw, 0o600); err != nil { + t.Fatalf("write session failed: %v", err) + } + + providerRuntimeRegistry.mu.Lock() + providerRuntimeRegistry.api["cooldown-provider"] = providerRuntimeState{ + API: providerAPIRuntimeState{ + TokenMasked: "api-***", + HealthScore: 50, + CooldownUntil: time.Now().Add(10 * time.Minute).Format(time.RFC3339), + }, + } + providerRuntimeRegistry.mu.Unlock() + + manager, err := newOAuthManager(config.ProviderConfig{ + Auth: "hybrid", + OAuth: config.ProviderOAuthConfig{ + Provider: "codex", + CredentialFile: credFile, + }, + }, 5*time.Second) + if err != nil { + t.Fatalf("new oauth manager failed: %v", err) + } + provider := NewHTTPProvider("cooldown-provider", "api-key-1", "https://example.com/v1", "gpt-test", false, "hybrid", 5*time.Second, manager) + attempts, err := provider.authAttempts(context.Background()) + if err != nil { + t.Fatalf("auth attempts failed: %v", err) + } + if len(attempts) != 1 || attempts[0].kind != "oauth" { + t.Fatalf("expected only oauth attempt during api cooldown, got %#v", attempts) + } +} + +func TestClearProviderAPICooldownRestoresAPIKeyAttempt(t *testing.T) { + t.Parallel() + + providerRuntimeRegistry.mu.Lock() + providerRuntimeRegistry.api["clear-api-provider"] = providerRuntimeState{ + API: providerAPIRuntimeState{ + TokenMasked: "api-***", + HealthScore: 50, + CooldownUntil: time.Now().Add(10 * time.Minute).Format(time.RFC3339), + }, + } + providerRuntimeRegistry.mu.Unlock() + + provider := NewHTTPProvider("clear-api-provider", "api-key-1", "https://example.com/v1", "gpt-test", false, "bearer", 5*time.Second, nil) + if _, err := provider.authAttempts(context.Background()); err == nil { + t.Fatalf("expected api key attempt to be blocked by cooldown") + } + ClearProviderAPICooldown("clear-api-provider") + attempts, err := provider.authAttempts(context.Background()) + if err != nil { + t.Fatalf("expected api key attempt after clear cooldown, got %v", err) + } + if len(attempts) != 1 || attempts[0].kind != "api_key" { + t.Fatalf("unexpected attempts after clear cooldown: %#v", attempts) + } +} + +func TestOAuthLoginManagerClearCooldown(t *testing.T) { + t.Parallel() + + dir := t.TempDir() + credFile := filepath.Join(dir, "oauth.json") + raw, err := json.Marshal(oauthSession{ + Provider: "codex", + AccessToken: "oauth-token", + Expire: time.Now().Add(time.Hour).Format(time.RFC3339), + }) + if err != nil { + t.Fatalf("marshal session failed: %v", err) + } + if err := os.WriteFile(credFile, raw, 0o600); err != nil { + t.Fatalf("write session failed: %v", err) + } + manager, err := newOAuthManager(config.ProviderConfig{ + Auth: "oauth", + OAuth: config.ProviderOAuthConfig{ + Provider: "codex", + CredentialFile: credFile, + CooldownSec: 3600, + }, + }, 5*time.Second) + if err != nil { + t.Fatalf("new oauth manager failed: %v", err) + } + attempts, err := manager.prepareAttemptsLocked(context.Background()) + if err != nil || len(attempts) != 1 { + t.Fatalf("prepare attempts failed: %v %#v", err, attempts) + } + manager.markExhausted(attempts[0].Session, oauthFailureRateLimit) + loginMgr := &OAuthLoginManager{manager: manager} + if err := loginMgr.ClearCooldown(credFile); err != nil { + t.Fatalf("clear cooldown failed: %v", err) + } + next, err := manager.prepareAttemptsLocked(context.Background()) + if err != nil || len(next) != 1 { + t.Fatalf("expected session available after cooldown clear, got err=%v attempts=%#v", err, next) + } +} + +func TestProviderRuntimeSnapshotIncludesCandidateOrderAndLastSuccess(t *testing.T) { + t.Parallel() + + dir := t.TempDir() + credFile := filepath.Join(dir, "oauth.json") + raw, err := json.Marshal(oauthSession{ + Provider: "codex", + AccessToken: "oauth-token", + RefreshToken: "refresh-token", + Email: "user@example.com", + Expire: time.Now().Add(time.Hour).Format(time.RFC3339), + }) + if err != nil { + t.Fatalf("marshal session failed: %v", err) + } + if err := os.WriteFile(credFile, raw, 0o600); err != nil { + t.Fatalf("write session failed: %v", err) + } + + name := "runtime-snapshot-provider" + pc := config.ProviderConfig{ + APIKey: "api-key-123456", + APIBase: "https://example.com/v1", + Auth: "hybrid", + TimeoutSec: 5, + RuntimePersist: true, + RuntimeHistoryFile: filepath.Join(dir, "runtime.json"), + OAuth: config.ProviderOAuthConfig{ + Provider: "codex", + CredentialFile: credFile, + HybridPriority: "api_first", + }, + } + ConfigureProviderRuntime(name, pc) + manager, err := newOAuthManager(pc, 5*time.Second) + if err != nil { + t.Fatalf("new oauth manager failed: %v", err) + } + provider := NewHTTPProvider(name, pc.APIKey, pc.APIBase, "gpt-test", false, pc.Auth, 5*time.Second, manager) + attempts, err := provider.authAttempts(context.Background()) + if err != nil { + t.Fatalf("auth attempts failed: %v", err) + } + if len(attempts) != 2 || attempts[0].kind != "api_key" || attempts[1].kind != "oauth" { + t.Fatalf("unexpected attempts order: %#v", attempts) + } + provider.markAttemptSuccess(attempts[1]) + + cfg := &config.Config{ + Providers: config.ProvidersConfig{ + Proxies: map[string]config.ProviderConfig{name: pc}, + }, + } + snapshot := GetProviderRuntimeSnapshot(cfg) + items, _ := snapshot["items"].([]map[string]interface{}) + if len(items) == 0 { + t.Fatalf("expected provider runtime items") + } + item := items[0] + candidates, _ := item["candidate_order"].([]providerRuntimeCandidate) + if len(candidates) < 2 { + t.Fatalf("expected candidate order, got %#v", item["candidate_order"]) + } + if candidates[0].Kind != "api_key" || candidates[1].Kind != "oauth" { + t.Fatalf("unexpected candidate order: %#v", candidates) + } + lastSuccess, _ := item["last_success"].(*providerRuntimeEvent) + if lastSuccess == nil || lastSuccess.Kind != "oauth" || lastSuccess.Target != "user@example.com" { + t.Fatalf("unexpected last success: %#v", item["last_success"]) + } + if _, err := os.Stat(pc.RuntimeHistoryFile); err != nil { + t.Fatalf("expected runtime history file, got %v", err) + } +} + +func TestConfigureProviderRuntimeLoadsPersistedEvents(t *testing.T) { + t.Parallel() + + dir := t.TempDir() + name := "persisted-runtime-provider" + historyFile := filepath.Join(dir, "runtime.json") + payload := providerRuntimeState{ + RecentHits: []providerRuntimeEvent{{ + When: time.Now().Add(-time.Minute).Format(time.RFC3339), + Kind: "oauth", + Target: "persisted@example.com", + Reason: "ok", + }}, + LastSuccess: &providerRuntimeEvent{ + When: time.Now().Add(-time.Minute).Format(time.RFC3339), + Kind: "oauth", + Target: "persisted@example.com", + }, + } + raw, err := json.Marshal(payload) + if err != nil { + t.Fatalf("marshal runtime payload failed: %v", err) + } + if err := os.WriteFile(historyFile, raw, 0o600); err != nil { + t.Fatalf("write history file failed: %v", err) + } + + ConfigureProviderRuntime(name, config.ProviderConfig{ + APIBase: "https://example.com/v1", + Auth: "bearer", + RuntimePersist: true, + RuntimeHistoryFile: historyFile, + }) + + cfg := &config.Config{ + Providers: config.ProvidersConfig{ + Proxies: map[string]config.ProviderConfig{ + name: { + APIBase: "https://example.com/v1", + Auth: "bearer", + RuntimePersist: true, + RuntimeHistoryFile: historyFile, + }, + }, + }, + } + snapshot := GetProviderRuntimeSnapshot(cfg) + items, _ := snapshot["items"].([]map[string]interface{}) + if len(items) == 0 { + t.Fatalf("expected provider runtime item") + } + lastSuccess, _ := items[0]["last_success"].(*providerRuntimeEvent) + if lastSuccess == nil || lastSuccess.Target != "persisted@example.com" { + t.Fatalf("expected persisted last success, got %#v", items[0]["last_success"]) + } + hits, _ := items[0]["recent_hits"].([]providerRuntimeEvent) + if len(hits) == 0 || hits[0].Target != "persisted@example.com" { + t.Fatalf("expected persisted recent hits, got %#v", items[0]["recent_hits"]) + } +} + +func TestClearProviderRuntimeHistoryRemovesPersistedFile(t *testing.T) { + t.Parallel() + + dir := t.TempDir() + name := "clear-runtime-history-provider" + historyFile := filepath.Join(dir, "runtime.json") + ConfigureProviderRuntime(name, config.ProviderConfig{ + APIBase: "https://example.com/v1", + Auth: "bearer", + RuntimePersist: true, + RuntimeHistoryFile: historyFile, + }) + + providerRuntimeRegistry.mu.Lock() + state := providerRuntimeRegistry.api[name] + state.RecentHits = []providerRuntimeEvent{{When: time.Now().Format(time.RFC3339), Kind: "api_key", Target: "api***"}} + state.LastSuccess = &providerRuntimeEvent{When: time.Now().Format(time.RFC3339), Kind: "api_key", Target: "api***"} + persistProviderRuntimeLocked(name, state) + providerRuntimeRegistry.api[name] = state + providerRuntimeRegistry.mu.Unlock() + + if _, err := os.Stat(historyFile); err != nil { + t.Fatalf("expected runtime history file, got %v", err) + } + ClearProviderRuntimeHistory(name) + if _, err := os.Stat(historyFile); !os.IsNotExist(err) { + t.Fatalf("expected runtime history file removed, got %v", err) + } + + providerRuntimeRegistry.mu.Lock() + cleared := providerRuntimeRegistry.api[name] + providerRuntimeRegistry.mu.Unlock() + if len(cleared.RecentHits) != 0 || len(cleared.RecentErrors) != 0 || cleared.LastSuccess != nil { + t.Fatalf("expected runtime history cleared, got %#v", cleared) + } +} + +func TestUpdateCandidateOrderRecordsSchedulerChange(t *testing.T) { + t.Parallel() + + dir := t.TempDir() + credFile := filepath.Join(dir, "oauth.json") + raw, err := json.Marshal(oauthSession{ + Provider: "codex", + AccessToken: "oauth-token", + Email: "user@example.com", + Expire: time.Now().Add(time.Hour).Format(time.RFC3339), + }) + if err != nil { + t.Fatalf("marshal session failed: %v", err) + } + if err := os.WriteFile(credFile, raw, 0o600); err != nil { + t.Fatalf("write session failed: %v", err) + } + + name := "candidate-change-provider" + pc := config.ProviderConfig{ + APIKey: "api-key-123456", + APIBase: "https://example.com/v1", + Auth: "hybrid", + TimeoutSec: 5, + OAuth: config.ProviderOAuthConfig{ + Provider: "codex", + CredentialFile: credFile, + HybridPriority: "api_first", + }, + } + manager, err := newOAuthManager(pc, 5*time.Second) + if err != nil { + t.Fatalf("new oauth manager failed: %v", err) + } + provider := NewHTTPProvider(name, pc.APIKey, pc.APIBase, "gpt-test", false, pc.Auth, 5*time.Second, manager) + attempts, err := provider.authAttempts(context.Background()) + if err != nil { + t.Fatalf("auth attempts failed: %v", err) + } + if len(attempts) != 2 { + t.Fatalf("unexpected attempts: %#v", attempts) + } + provider.markAPIKeyFailure(oauthFailureRateLimit) + attempts, err = provider.authAttempts(context.Background()) + if err != nil { + t.Fatalf("auth attempts after cooldown failed: %v", err) + } + if len(attempts) != 1 || attempts[0].kind != "oauth" { + t.Fatalf("unexpected attempts after cooldown: %#v", attempts) + } + + providerRuntimeRegistry.mu.Lock() + state := providerRuntimeRegistry.api[name] + providerRuntimeRegistry.mu.Unlock() + if len(state.RecentChanges) == 0 || state.RecentChanges[0].Reason != "candidate_order_changed" { + t.Fatalf("expected scheduler change event, got %#v", state.RecentChanges) + } + if !strings.Contains(state.RecentChanges[0].Detail, "top ") { + t.Fatalf("expected candidate order detail, got %#v", state.RecentChanges[0]) + } +} + +func TestGetProviderRuntimeViewFiltersEvents(t *testing.T) { + t.Parallel() + + name := "runtime-view-provider" + providerRuntimeRegistry.mu.Lock() + providerRuntimeRegistry.api[name] = providerRuntimeState{ + RecentHits: []providerRuntimeEvent{ + {When: time.Now().Add(-30 * time.Minute).Format(time.RFC3339), Kind: "oauth", Target: "user@example.com", Reason: "ok"}, + {When: time.Now().Add(-3 * time.Hour).Format(time.RFC3339), Kind: "api_key", Target: "api***", Reason: "ok"}, + }, + RecentErrors: []providerRuntimeEvent{ + {When: time.Now().Add(-10 * time.Minute).Format(time.RFC3339), Kind: "oauth", Target: "user@example.com", Reason: "quota"}, + }, + RecentChanges: []providerRuntimeEvent{ + {When: time.Now().Add(-5 * time.Minute).Format(time.RFC3339), Kind: "scheduler", Target: name, Reason: "candidate_order_changed", Detail: "top api -> oauth"}, + }, + } + providerRuntimeRegistry.mu.Unlock() + + cfg := &config.Config{ + Providers: config.ProvidersConfig{ + Proxies: map[string]config.ProviderConfig{ + name: {APIBase: "https://example.com/v1", Auth: "hybrid", APIKey: "api-key"}, + }, + }, + } + view := GetProviderRuntimeView(cfg, ProviderRuntimeQuery{ + Provider: name, + Window: 2 * time.Hour, + EventKind: "oauth", + Limit: 1, + }) + items, _ := view["items"].([]map[string]interface{}) + if len(items) != 1 { + t.Fatalf("expected one runtime item, got %#v", view) + } + hits, _ := items[0]["recent_hits"].([]providerRuntimeEvent) + if len(hits) != 1 || hits[0].Kind != "oauth" { + t.Fatalf("expected filtered oauth hits, got %#v", items[0]["recent_hits"]) + } + errors, _ := items[0]["recent_errors"].([]providerRuntimeEvent) + if len(errors) != 1 || errors[0].Reason != "quota" { + t.Fatalf("expected filtered oauth errors, got %#v", items[0]["recent_errors"]) + } + changes, _ := items[0]["recent_changes"].([]providerRuntimeEvent) + if len(changes) != 0 { + t.Fatalf("expected no scheduler changes when filtering kind=oauth, got %#v", items[0]["recent_changes"]) + } + events, _ := items[0]["events"].([]providerRuntimeEvent) + if len(events) != 1 { + t.Fatalf("expected merged paged events, got %#v", items[0]["events"]) + } +} + +func TestGetProviderRuntimeViewCursorPagination(t *testing.T) { + t.Parallel() + + name := "runtime-cursor-provider" + now := time.Now() + providerRuntimeRegistry.mu.Lock() + providerRuntimeRegistry.api[name] = providerRuntimeState{ + RecentHits: []providerRuntimeEvent{ + {When: now.Add(-1 * time.Minute).Format(time.RFC3339), Kind: "oauth", Target: "a", Reason: "ok"}, + {When: now.Add(-2 * time.Minute).Format(time.RFC3339), Kind: "oauth", Target: "b", Reason: "ok"}, + }, + RecentErrors: []providerRuntimeEvent{ + {When: now.Add(-3 * time.Minute).Format(time.RFC3339), Kind: "oauth", Target: "c", Reason: "quota"}, + }, + } + providerRuntimeRegistry.mu.Unlock() + cfg := &config.Config{ + Providers: config.ProvidersConfig{ + Proxies: map[string]config.ProviderConfig{ + name: {APIBase: "https://example.com/v1", Auth: "oauth", OAuth: config.ProviderOAuthConfig{Provider: "codex"}}, + }, + }, + } + view := GetProviderRuntimeView(cfg, ProviderRuntimeQuery{Provider: name, Limit: 2, Cursor: 0}) + items, _ := view["items"].([]map[string]interface{}) + if len(items) != 1 { + t.Fatalf("expected one item, got %#v", view) + } + page1, _ := items[0]["events"].([]providerRuntimeEvent) + if len(page1) != 2 || items[0]["next_cursor"].(int) != 2 { + t.Fatalf("unexpected first page %#v", items[0]) + } + view = GetProviderRuntimeView(cfg, ProviderRuntimeQuery{Provider: name, Limit: 2, Cursor: 2}) + items, _ = view["items"].([]map[string]interface{}) + page2, _ := items[0]["events"].([]providerRuntimeEvent) + if len(page2) != 1 || items[0]["next_cursor"].(int) != 0 { + t.Fatalf("unexpected second page %#v", items[0]) + } +} + +func TestGetProviderRuntimeViewSortAscending(t *testing.T) { + t.Parallel() + + name := "runtime-sort-provider" + now := time.Now() + providerRuntimeRegistry.mu.Lock() + providerRuntimeRegistry.api[name] = providerRuntimeState{ + RecentHits: []providerRuntimeEvent{ + {When: now.Add(-1 * time.Minute).Format(time.RFC3339), Kind: "oauth", Target: "a", Reason: "ok"}, + {When: now.Add(-3 * time.Minute).Format(time.RFC3339), Kind: "oauth", Target: "b", Reason: "ok"}, + }, + RecentErrors: []providerRuntimeEvent{ + {When: now.Add(-2 * time.Minute).Format(time.RFC3339), Kind: "oauth", Target: "c", Reason: "quota"}, + }, + } + providerRuntimeRegistry.mu.Unlock() + cfg := &config.Config{ + Providers: config.ProvidersConfig{ + Proxies: map[string]config.ProviderConfig{ + name: {APIBase: "https://example.com/v1", Auth: "oauth", OAuth: config.ProviderOAuthConfig{Provider: "codex"}}, + }, + }, + } + view := GetProviderRuntimeView(cfg, ProviderRuntimeQuery{Provider: name, Limit: 10, Sort: "asc"}) + items, _ := view["items"].([]map[string]interface{}) + if len(items) != 1 { + t.Fatalf("expected one item, got %#v", view) + } + events, _ := items[0]["events"].([]providerRuntimeEvent) + if len(events) != 3 { + t.Fatalf("expected three events, got %#v", items[0]["events"]) + } + if events[0].Target != "b" || events[1].Target != "c" || events[2].Target != "a" { + t.Fatalf("expected ascending order oldest->newest, got %#v", events) + } +} + +func TestGetProviderRuntimeViewFiltersByHealthAndCooldown(t *testing.T) { + t.Parallel() + + name := "runtime-health-provider" + providerRuntimeRegistry.mu.Lock() + providerRuntimeRegistry.api[name] = providerRuntimeState{ + API: providerAPIRuntimeState{ + HealthScore: 20, + CooldownUntil: time.Now().Add(-5 * time.Minute).Format(time.RFC3339), + }, + CandidateOrder: []providerRuntimeCandidate{{ + Kind: "api_key", + Target: "api***", + HealthScore: 20, + CooldownUntil: time.Now().Add(-5 * time.Minute).Format(time.RFC3339), + }}, + } + providerRuntimeRegistry.mu.Unlock() + cfg := &config.Config{ + Providers: config.ProvidersConfig{ + Proxies: map[string]config.ProviderConfig{ + name: {APIBase: "https://example.com/v1", Auth: "bearer", APIKey: "api-key"}, + }, + }, + } + view := GetProviderRuntimeView(cfg, ProviderRuntimeQuery{ + Provider: name, + HealthBelow: 30, + CooldownBefore: time.Now(), + }) + items, _ := view["items"].([]map[string]interface{}) + if len(items) != 1 { + t.Fatalf("expected one filtered runtime item, got %#v", view) + } +} + +func TestGetProviderRuntimeSummaryFlagsUnhealthyProviders(t *testing.T) { + t.Parallel() + + name := "runtime-summary-provider" + lastSuccessAt := time.Now().Add(-2 * time.Hour).Format(time.RFC3339) + topChangedAt := time.Now().Add(-3 * time.Minute).Format(time.RFC3339) + providerRuntimeRegistry.mu.Lock() + providerRuntimeRegistry.api[name] = providerRuntimeState{ + API: providerAPIRuntimeState{ + HealthScore: 25, + CooldownUntil: time.Now().Add(15 * time.Minute).Format(time.RFC3339), + }, + RecentErrors: []providerRuntimeEvent{ + {When: time.Now().Add(-5 * time.Minute).Format(time.RFC3339), Kind: "api_key", Target: "api***", Reason: "quota"}, + }, + RecentChanges: []providerRuntimeEvent{ + {When: topChangedAt, Kind: "scheduler", Target: name, Reason: "candidate_order_changed", Detail: "top oauth -> api"}, + }, + LastSuccess: &providerRuntimeEvent{ + When: lastSuccessAt, + Kind: "api_key", + Target: "api***", + Reason: "ok", + }, + CandidateOrder: []providerRuntimeCandidate{{ + Kind: "api_key", + Target: "api***", + Available: false, + Status: "cooldown", + HealthScore: 25, + CooldownUntil: time.Now().Add(15 * time.Minute).Format(time.RFC3339), + }}, + } + providerRuntimeRegistry.mu.Unlock() + cfg := &config.Config{ + Providers: config.ProvidersConfig{ + Proxies: map[string]config.ProviderConfig{ + name: {APIBase: "https://example.com/v1", Auth: "bearer", APIKey: "api-key"}, + }, + }, + } + summary := GetProviderRuntimeSummary(cfg, ProviderRuntimeQuery{HealthBelow: 30, Window: time.Hour}) + if summary.TotalProviders != 1 || summary.InCooldown != 1 || summary.LowHealth != 1 || summary.RecentErrors != 1 { + t.Fatalf("unexpected summary counts: %#v", summary) + } + if summary.Critical != 1 || summary.Degraded != 0 || summary.Healthy != 0 { + t.Fatalf("unexpected status counts: %#v", summary) + } + if len(summary.Providers) != 1 || summary.Providers[0].TopCandidate == nil || summary.Providers[0].TopCandidate.Kind != "api_key" { + t.Fatalf("unexpected provider summary items: %#v", summary.Providers) + } + if summary.Providers[0].Status != "critical" { + t.Fatalf("expected critical status, got %#v", summary.Providers[0]) + } + if summary.Providers[0].LastError == nil || summary.Providers[0].LastErrorReason != "quota" || summary.Providers[0].LastErrorAt == "" { + t.Fatalf("expected last error details, got %#v", summary.Providers[0]) + } + if summary.Providers[0].LastSuccessAt != lastSuccessAt || summary.Providers[0].TopCandidateChangedAt != topChangedAt { + t.Fatalf("expected last success and top candidate timestamps, got %#v", summary.Providers[0]) + } + if summary.Providers[0].StaleForSec < 7100 || summary.Providers[0].StaleForSec > 7300 { + t.Fatalf("expected stale_for_sec around 2h, got %#v", summary.Providers[0].StaleForSec) + } +} + +func TestGetProviderRuntimeSummaryMarksRecentErrorsAsDegraded(t *testing.T) { + t.Parallel() + + name := "runtime-summary-degraded-provider" + providerRuntimeRegistry.mu.Lock() + providerRuntimeRegistry.api[name] = providerRuntimeState{ + RecentErrors: []providerRuntimeEvent{ + {When: time.Now().Add(-10 * time.Minute).Format(time.RFC3339), Kind: "oauth", Target: "user@example.com", Reason: "rate_limit"}, + }, + CandidateOrder: []providerRuntimeCandidate{{ + Kind: "oauth", + Target: "user@example.com", + Available: true, + Status: "ready", + HealthScore: 90, + }}, + } + providerRuntimeRegistry.mu.Unlock() + cfg := &config.Config{ + Providers: config.ProvidersConfig{ + Proxies: map[string]config.ProviderConfig{ + name: {APIBase: "https://example.com/v1", Auth: "oauth", OAuth: config.ProviderOAuthConfig{Provider: "codex"}}, + }, + }, + } + + summary := GetProviderRuntimeSummary(cfg, ProviderRuntimeQuery{HealthBelow: 30, Window: time.Hour}) + if summary.TotalProviders != 1 || summary.Degraded != 1 || summary.Critical != 0 || summary.Healthy != 0 { + t.Fatalf("unexpected summary counts: %#v", summary) + } + if len(summary.Providers) != 1 || summary.Providers[0].Status != "degraded" { + t.Fatalf("expected degraded provider item, got %#v", summary.Providers) + } + if summary.Providers[0].LastErrorReason != "rate_limit" { + t.Fatalf("expected last error reason, got %#v", summary.Providers[0]) + } + if summary.Providers[0].StaleForSec != -1 { + t.Fatalf("expected stale_for_sec=-1 without success event, got %#v", summary.Providers[0]) + } +} + +func TestGetProviderRuntimeSummaryIncludesOAuthAccountMetadata(t *testing.T) { + t.Parallel() + + dir := t.TempDir() + credFile := filepath.Join(dir, "qwen.json") + raw, err := json.Marshal(oauthSession{ + Provider: "qwen", + AccessToken: "qwen-token", + RefreshToken: "refresh-token", + Email: "qwen-label", + ProjectID: "proj-9", + DeviceID: "device-9", + ResourceURL: "https://chat.qwen.ai/api", + Expire: time.Now().Add(time.Hour).Format(time.RFC3339), + FilePath: credFile, + }) + if err != nil { + t.Fatalf("marshal session failed: %v", err) + } + if err := os.WriteFile(credFile, raw, 0o600); err != nil { + t.Fatalf("write session failed: %v", err) + } + cfg := &config.Config{ + Providers: config.ProvidersConfig{ + Proxies: map[string]config.ProviderConfig{ + "qwen-summary": { + APIBase: "https://example.com/v1", + Auth: "oauth", + TimeoutSec: 5, + OAuth: config.ProviderOAuthConfig{ + Provider: "qwen", + CredentialFile: credFile, + }, + }, + }, + }, + } + summary := GetProviderRuntimeSummary(cfg, ProviderRuntimeQuery{Provider: "qwen-summary", HealthBelow: 50}) + if len(summary.Providers) != 1 { + t.Fatalf("expected one provider, got %#v", summary) + } + if len(summary.Providers[0].OAuthAccounts) != 1 { + t.Fatalf("expected oauth account metadata, got %#v", summary.Providers[0]) + } + account := summary.Providers[0].OAuthAccounts[0] + if account.AccountLabel != "qwen-label" || account.ProjectID != "proj-9" || account.DeviceID != "device-9" || account.ResourceURL != "https://chat.qwen.ai/api" { + t.Fatalf("unexpected oauth account metadata: %#v", account) + } +} + +func TestRefreshProviderRuntimeNowSupportsOnlyExpiring(t *testing.T) { + t.Parallel() + + var refreshCalls int32 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/oauth/token" { + http.NotFound(w, r) + return + } + atomic.AddInt32(&refreshCalls, 1) + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"access_token":"refreshed-token","refresh_token":"refresh-token","expires_in":3600}`)) + })) + defer server.Close() + + credFile := filepath.Join(t.TempDir(), "codex.json") + raw, err := json.Marshal(oauthSession{ + Provider: "codex", + AccessToken: "old-token", + RefreshToken: "refresh-token", + Expire: time.Now().Add(24 * time.Hour).Format(time.RFC3339), + }) + if err != nil { + t.Fatalf("marshal session failed: %v", err) + } + if err := os.WriteFile(credFile, raw, 0o600); err != nil { + t.Fatalf("write credential file failed: %v", err) + } + + name := "runtime-refresh-provider" + cfg := &config.Config{ + Providers: config.ProvidersConfig{ + Proxies: map[string]config.ProviderConfig{ + name: { + APIBase: server.URL + "/v1", + Auth: "oauth", + TimeoutSec: 5, + OAuth: config.ProviderOAuthConfig{ + Provider: "codex", + CredentialFile: credFile, + ClientID: "test-client", + TokenURL: server.URL + "/oauth/token", + AuthURL: server.URL + "/oauth/authorize", + RefreshLeadSec: 1800, + }, + }, + }, + }, + } + + result, err := RefreshProviderRuntimeNow(cfg, name, true) + if err != nil { + t.Fatalf("refresh only expiring failed: %v", err) + } + if result == nil || result.Refreshed != 0 || result.Skipped != 1 { + t.Fatalf("expected skip for non-expiring session, got %#v", result) + } + if atomic.LoadInt32(&refreshCalls) != 0 { + t.Fatalf("expected no refresh calls for only-expiring path, got %d", refreshCalls) + } + + result, err = RefreshProviderRuntimeNow(cfg, name, false) + if err != nil { + t.Fatalf("refresh all failed: %v", err) + } + if result == nil || result.Refreshed != 1 { + t.Fatalf("expected forced refresh, got %#v", result) + } + if atomic.LoadInt32(&refreshCalls) != 1 { + t.Fatalf("expected one refresh call, got %d", refreshCalls) + } +} + +func TestRerankProviderRuntimeUpdatesCandidateOrder(t *testing.T) { + t.Parallel() + + dir := t.TempDir() + credFile := filepath.Join(dir, "oauth.json") + raw, err := json.Marshal(oauthSession{ + Provider: "codex", + AccessToken: "oauth-token", + Email: "rerank@example.com", + Expire: time.Now().Add(time.Hour).Format(time.RFC3339), + }) + if err != nil { + t.Fatalf("marshal session failed: %v", err) + } + if err := os.WriteFile(credFile, raw, 0o600); err != nil { + t.Fatalf("write session failed: %v", err) + } + name := "rerank-runtime-provider" + cfg := &config.Config{ + Providers: config.ProvidersConfig{ + Proxies: map[string]config.ProviderConfig{ + name: { + APIKey: "api-key", + APIBase: "https://example.com/v1", + Auth: "hybrid", + TimeoutSec: 5, + OAuth: config.ProviderOAuthConfig{ + Provider: "codex", + CredentialFile: credFile, + HybridPriority: "oauth_first", + }, + }, + }, + }, + } + order, err := RerankProviderRuntime(cfg, name) + if err != nil { + t.Fatalf("rerank provider runtime failed: %v", err) + } + if len(order) == 0 || order[0].Kind != "oauth" { + t.Fatalf("expected oauth-first rerank result, got %#v", order) + } + snapshot := GetProviderRuntimeSnapshot(cfg) + items, _ := snapshot["items"].([]map[string]interface{}) + if len(items) != 1 { + t.Fatalf("expected one runtime item, got %#v", snapshot) + } + snapshotOrder, _ := items[0]["candidate_order"].([]providerRuntimeCandidate) + if len(snapshotOrder) == 0 || snapshotOrder[0].Kind != "oauth" { + t.Fatalf("expected oauth-first candidate order, got %#v", items[0]["candidate_order"]) + } +} diff --git a/webui/src/App.tsx b/webui/src/App.tsx index 26db056..4e86e68 100644 --- a/webui/src/App.tsx +++ b/webui/src/App.tsx @@ -7,6 +7,7 @@ import Layout from './components/Layout'; const Dashboard = lazy(() => import('./pages/Dashboard')); const Chat = lazy(() => import('./pages/Chat')); const Config = lazy(() => import('./pages/Config')); +const Providers = lazy(() => import('./pages/Providers')); const Cron = lazy(() => import('./pages/Cron')); const Logs = lazy(() => import('./pages/Logs')); const Skills = lazy(() => import('./pages/Skills')); @@ -44,6 +45,7 @@ export default function App() { } /> } /> } /> + } /> } /> } /> } /> diff --git a/webui/src/components/FormControls.tsx b/webui/src/components/FormControls.tsx new file mode 100644 index 0000000..9513382 --- /dev/null +++ b/webui/src/components/FormControls.tsx @@ -0,0 +1,115 @@ +import React from 'react'; + +function joinClasses(...values: Array) { + return values.filter(Boolean).join(' '); +} + +type TextFieldProps = Omit, 'className'> & { + dense?: boolean; + monospace?: boolean; + className?: string; +}; + +type SelectFieldProps = Omit, 'className'> & { + dense?: boolean; + className?: string; +}; + +type TextareaFieldProps = Omit, 'className'> & { + dense?: boolean; + monospace?: boolean; + className?: string; +}; + +type CheckboxFieldProps = Omit, 'type' | 'className'> & { + className?: string; +}; + +type FieldBlockProps = { + label?: React.ReactNode; + help?: React.ReactNode; + meta?: React.ReactNode; + className?: string; + children: React.ReactNode; +}; + +type PanelFieldProps = FieldBlockProps & { + dense?: boolean; +}; + +export function TextField({ dense = false, monospace = false, className, ...props }: TextFieldProps) { + return ( + + ); +} + +export function SelectField({ dense = false, className, children, ...props }: SelectFieldProps) { + return ( + + ); +} + +export function TextareaField({ dense = false, monospace = false, className, ...props }: TextareaFieldProps) { + return ( +