Add OAuth provider runtime and providers UI

This commit is contained in:
lpf
2026-03-11 15:47:49 +08:00
parent d9872c3da7
commit 1c0e463d07
52 changed files with 9772 additions and 901 deletions

View File

@@ -2,15 +2,18 @@ package main
import (
"bufio"
"context"
"encoding/json"
"fmt"
"os"
"sort"
"strconv"
"strings"
"time"
"github.com/YspCoder/clawgo/pkg/config"
"github.com/YspCoder/clawgo/pkg/configops"
"github.com/YspCoder/clawgo/pkg/providers"
)
func configCmd() {
@@ -172,6 +175,16 @@ func configCheckCmd() {
}
func providerCmd() {
if len(os.Args) >= 3 {
switch strings.TrimSpace(os.Args[2]) {
case "login":
providerLoginCmd()
return
case "configure":
// Continue into the interactive editor below.
}
}
cfg, err := loadConfig()
if err != nil {
fmt.Printf("Error loading config: %v\n", err)
@@ -222,10 +235,19 @@ func providerCmd() {
if models := parseCSV(modelsRaw); len(models) > 0 {
pc.Models = models
}
pc.Auth = promptLine(reader, "auth (bearer/oauth/none)", pc.Auth)
pc.Auth = promptLine(reader, "auth (bearer/oauth/hybrid/none)", pc.Auth)
timeoutRaw := promptLine(reader, "timeout_sec", fmt.Sprintf("%d", pc.TimeoutSec))
pc.TimeoutSec = parseIntOrDefault(timeoutRaw, pc.TimeoutSec)
pc.SupportsResponsesCompact = promptBool(reader, "supports_responses_compact", pc.SupportsResponsesCompact)
if strings.EqualFold(strings.TrimSpace(pc.Auth), "oauth") || strings.EqualFold(strings.TrimSpace(pc.Auth), "hybrid") {
pc.OAuth.Provider = promptLine(reader, "oauth.provider", firstNonEmptyString(pc.OAuth.Provider, "codex"))
pc.OAuth.CredentialFile = promptLine(reader, "oauth.credential_file", pc.OAuth.CredentialFile)
pc.OAuth.CallbackPort = parseIntOrDefault(promptLine(reader, "oauth.callback_port", fmt.Sprintf("%d", defaultInt(pc.OAuth.CallbackPort, 1455))), defaultInt(pc.OAuth.CallbackPort, 1455))
if strings.EqualFold(strings.TrimSpace(pc.Auth), "hybrid") {
pc.OAuth.HybridPriority = promptLine(reader, "oauth.hybrid_priority (api_first/oauth_first)", firstNonEmptyString(pc.OAuth.HybridPriority, "api_first"))
}
pc.OAuth.CooldownSec = parseIntOrDefault(promptLine(reader, "oauth.cooldown_sec", fmt.Sprintf("%d", defaultInt(pc.OAuth.CooldownSec, 900))), defaultInt(pc.OAuth.CooldownSec, 900))
}
setProviderConfigByName(cfg, providerName, pc)
@@ -278,6 +300,119 @@ func providerCmd() {
fmt.Println("鉁?Gateway hot reload signal sent")
}
func providerLoginCmd() {
cfg, err := loadConfig()
if err != nil {
fmt.Printf("Error loading config: %v\n", err)
os.Exit(1)
}
providerName := strings.TrimSpace(cfg.Agents.Defaults.Proxy)
if providerName == "" {
providerName = "proxy"
}
manual := false
noBrowser := false
accountLabel := ""
for i := 3; i < len(os.Args); i++ {
arg := strings.TrimSpace(os.Args[i])
switch arg {
case "--manual":
manual = true
case "--no-browser":
noBrowser = true
case "--label":
if i+1 < len(os.Args) {
i++
accountLabel = strings.TrimSpace(os.Args[i])
}
case "":
default:
if strings.HasPrefix(arg, "--label=") {
accountLabel = strings.TrimSpace(strings.TrimPrefix(arg, "--label="))
continue
}
if !strings.HasPrefix(arg, "-") {
providerName = arg
}
}
}
pc := providerConfigByName(cfg, providerName)
if !strings.EqualFold(strings.TrimSpace(pc.Auth), "oauth") && !strings.EqualFold(strings.TrimSpace(pc.Auth), "hybrid") {
fmt.Printf("Provider %s is not configured with auth=oauth/hybrid\n", providerName)
os.Exit(1)
}
if manual {
noBrowser = true
}
if manual && strings.TrimSpace(pc.OAuth.RedirectURL) == "" && pc.OAuth.CallbackPort <= 0 {
pc.OAuth.CallbackPort = 1455
}
timeout := pc.TimeoutSec
if timeout <= 0 {
timeout = 90
}
oauth, err := providers.NewOAuthLoginManager(pc, time.Duration(timeout)*time.Second)
if err != nil {
fmt.Printf("Error preparing oauth login: %v\n", err)
os.Exit(1)
}
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute)
defer cancel()
session, models, err := oauth.Login(ctx, pc.APIBase, providers.OAuthLoginOptions{
Manual: manual,
NoBrowser: noBrowser,
Reader: os.Stdin,
AccountLabel: accountLabel,
})
if err != nil {
fmt.Printf("OAuth login failed: %v\n", err)
os.Exit(1)
}
if len(models) > 0 {
pc.Models = models
}
if session.CredentialFile != "" {
pc.OAuth.CredentialFile = session.CredentialFile
pc.OAuth.CredentialFiles = appendUniqueCSV(pc.OAuth.CredentialFiles, session.CredentialFile)
} else if pc.OAuth.CredentialFile == "" {
pc.OAuth.CredentialFile = oauth.CredentialFile()
pc.OAuth.CredentialFiles = appendUniqueCSV(pc.OAuth.CredentialFiles, pc.OAuth.CredentialFile)
}
setProviderConfigByName(cfg, providerName, pc)
if err := config.SaveConfig(getConfigPath(), cfg); err != nil {
fmt.Printf("Error saving config: %v\n", err)
os.Exit(1)
}
fmt.Printf("OAuth login succeeded for provider %s\n", providerName)
if manual {
fmt.Println("Mode: manual callback URL paste")
} else if noBrowser {
fmt.Println("Mode: local callback listener without auto-opening browser")
}
if session.Email != "" {
fmt.Printf("Account: %s\n", session.Email)
}
fmt.Printf("Credential file: %s\n", firstNonEmptyString(session.CredentialFile, oauth.CredentialFile()))
if len(pc.OAuth.CredentialFiles) > 1 {
fmt.Printf("OAuth accounts: %d\n", len(pc.OAuth.CredentialFiles))
}
if len(models) > 0 {
fmt.Printf("Discovered models: %s\n", strings.Join(models, ", "))
}
if running, reloadErr := triggerGatewayReload(); reloadErr == nil {
fmt.Println("Gateway hot reload signal sent")
} else if running {
fmt.Printf("Hot reload not applied: %v\n", reloadErr)
}
}
func providerNames(cfg *config.Config) []string {
names := []string{"proxy"}
for k := range cfg.Providers.Proxies {
@@ -370,6 +505,35 @@ func parseCSV(raw string) []string {
return out
}
func defaultInt(value, fallback int) int {
if value > 0 {
return value
}
return fallback
}
func firstNonEmptyString(values ...string) string {
for _, value := range values {
if trimmed := strings.TrimSpace(value); trimmed != "" {
return trimmed
}
}
return ""
}
func appendUniqueCSV(values []string, value string) []string {
trimmed := strings.TrimSpace(value)
if trimmed == "" {
return values
}
for _, item := range values {
if strings.TrimSpace(item) == trimmed {
return values
}
}
return append(values, trimmed)
}
func parseIntOrDefault(raw string, def int) int {
raw = strings.TrimSpace(raw)
if raw == "" {