mirror of
https://github.com/YspCoder/clawgo.git
synced 2026-04-14 19:37:31 +08:00
Add OAuth provider runtime and providers UI
This commit is contained in:
@@ -2,15 +2,18 @@ package main
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/YspCoder/clawgo/pkg/config"
|
||||
"github.com/YspCoder/clawgo/pkg/configops"
|
||||
"github.com/YspCoder/clawgo/pkg/providers"
|
||||
)
|
||||
|
||||
func configCmd() {
|
||||
@@ -172,6 +175,16 @@ func configCheckCmd() {
|
||||
}
|
||||
|
||||
func providerCmd() {
|
||||
if len(os.Args) >= 3 {
|
||||
switch strings.TrimSpace(os.Args[2]) {
|
||||
case "login":
|
||||
providerLoginCmd()
|
||||
return
|
||||
case "configure":
|
||||
// Continue into the interactive editor below.
|
||||
}
|
||||
}
|
||||
|
||||
cfg, err := loadConfig()
|
||||
if err != nil {
|
||||
fmt.Printf("Error loading config: %v\n", err)
|
||||
@@ -222,10 +235,19 @@ func providerCmd() {
|
||||
if models := parseCSV(modelsRaw); len(models) > 0 {
|
||||
pc.Models = models
|
||||
}
|
||||
pc.Auth = promptLine(reader, "auth (bearer/oauth/none)", pc.Auth)
|
||||
pc.Auth = promptLine(reader, "auth (bearer/oauth/hybrid/none)", pc.Auth)
|
||||
timeoutRaw := promptLine(reader, "timeout_sec", fmt.Sprintf("%d", pc.TimeoutSec))
|
||||
pc.TimeoutSec = parseIntOrDefault(timeoutRaw, pc.TimeoutSec)
|
||||
pc.SupportsResponsesCompact = promptBool(reader, "supports_responses_compact", pc.SupportsResponsesCompact)
|
||||
if strings.EqualFold(strings.TrimSpace(pc.Auth), "oauth") || strings.EqualFold(strings.TrimSpace(pc.Auth), "hybrid") {
|
||||
pc.OAuth.Provider = promptLine(reader, "oauth.provider", firstNonEmptyString(pc.OAuth.Provider, "codex"))
|
||||
pc.OAuth.CredentialFile = promptLine(reader, "oauth.credential_file", pc.OAuth.CredentialFile)
|
||||
pc.OAuth.CallbackPort = parseIntOrDefault(promptLine(reader, "oauth.callback_port", fmt.Sprintf("%d", defaultInt(pc.OAuth.CallbackPort, 1455))), defaultInt(pc.OAuth.CallbackPort, 1455))
|
||||
if strings.EqualFold(strings.TrimSpace(pc.Auth), "hybrid") {
|
||||
pc.OAuth.HybridPriority = promptLine(reader, "oauth.hybrid_priority (api_first/oauth_first)", firstNonEmptyString(pc.OAuth.HybridPriority, "api_first"))
|
||||
}
|
||||
pc.OAuth.CooldownSec = parseIntOrDefault(promptLine(reader, "oauth.cooldown_sec", fmt.Sprintf("%d", defaultInt(pc.OAuth.CooldownSec, 900))), defaultInt(pc.OAuth.CooldownSec, 900))
|
||||
}
|
||||
|
||||
setProviderConfigByName(cfg, providerName, pc)
|
||||
|
||||
@@ -278,6 +300,119 @@ func providerCmd() {
|
||||
fmt.Println("鉁?Gateway hot reload signal sent")
|
||||
}
|
||||
|
||||
func providerLoginCmd() {
|
||||
cfg, err := loadConfig()
|
||||
if err != nil {
|
||||
fmt.Printf("Error loading config: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
providerName := strings.TrimSpace(cfg.Agents.Defaults.Proxy)
|
||||
if providerName == "" {
|
||||
providerName = "proxy"
|
||||
}
|
||||
manual := false
|
||||
noBrowser := false
|
||||
accountLabel := ""
|
||||
for i := 3; i < len(os.Args); i++ {
|
||||
arg := strings.TrimSpace(os.Args[i])
|
||||
switch arg {
|
||||
case "--manual":
|
||||
manual = true
|
||||
case "--no-browser":
|
||||
noBrowser = true
|
||||
case "--label":
|
||||
if i+1 < len(os.Args) {
|
||||
i++
|
||||
accountLabel = strings.TrimSpace(os.Args[i])
|
||||
}
|
||||
case "":
|
||||
default:
|
||||
if strings.HasPrefix(arg, "--label=") {
|
||||
accountLabel = strings.TrimSpace(strings.TrimPrefix(arg, "--label="))
|
||||
continue
|
||||
}
|
||||
if !strings.HasPrefix(arg, "-") {
|
||||
providerName = arg
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pc := providerConfigByName(cfg, providerName)
|
||||
if !strings.EqualFold(strings.TrimSpace(pc.Auth), "oauth") && !strings.EqualFold(strings.TrimSpace(pc.Auth), "hybrid") {
|
||||
fmt.Printf("Provider %s is not configured with auth=oauth/hybrid\n", providerName)
|
||||
os.Exit(1)
|
||||
}
|
||||
if manual {
|
||||
noBrowser = true
|
||||
}
|
||||
if manual && strings.TrimSpace(pc.OAuth.RedirectURL) == "" && pc.OAuth.CallbackPort <= 0 {
|
||||
pc.OAuth.CallbackPort = 1455
|
||||
}
|
||||
|
||||
timeout := pc.TimeoutSec
|
||||
if timeout <= 0 {
|
||||
timeout = 90
|
||||
}
|
||||
oauth, err := providers.NewOAuthLoginManager(pc, time.Duration(timeout)*time.Second)
|
||||
if err != nil {
|
||||
fmt.Printf("Error preparing oauth login: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute)
|
||||
defer cancel()
|
||||
session, models, err := oauth.Login(ctx, pc.APIBase, providers.OAuthLoginOptions{
|
||||
Manual: manual,
|
||||
NoBrowser: noBrowser,
|
||||
Reader: os.Stdin,
|
||||
AccountLabel: accountLabel,
|
||||
})
|
||||
if err != nil {
|
||||
fmt.Printf("OAuth login failed: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if len(models) > 0 {
|
||||
pc.Models = models
|
||||
}
|
||||
if session.CredentialFile != "" {
|
||||
pc.OAuth.CredentialFile = session.CredentialFile
|
||||
pc.OAuth.CredentialFiles = appendUniqueCSV(pc.OAuth.CredentialFiles, session.CredentialFile)
|
||||
} else if pc.OAuth.CredentialFile == "" {
|
||||
pc.OAuth.CredentialFile = oauth.CredentialFile()
|
||||
pc.OAuth.CredentialFiles = appendUniqueCSV(pc.OAuth.CredentialFiles, pc.OAuth.CredentialFile)
|
||||
}
|
||||
setProviderConfigByName(cfg, providerName, pc)
|
||||
|
||||
if err := config.SaveConfig(getConfigPath(), cfg); err != nil {
|
||||
fmt.Printf("Error saving config: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
fmt.Printf("OAuth login succeeded for provider %s\n", providerName)
|
||||
if manual {
|
||||
fmt.Println("Mode: manual callback URL paste")
|
||||
} else if noBrowser {
|
||||
fmt.Println("Mode: local callback listener without auto-opening browser")
|
||||
}
|
||||
if session.Email != "" {
|
||||
fmt.Printf("Account: %s\n", session.Email)
|
||||
}
|
||||
fmt.Printf("Credential file: %s\n", firstNonEmptyString(session.CredentialFile, oauth.CredentialFile()))
|
||||
if len(pc.OAuth.CredentialFiles) > 1 {
|
||||
fmt.Printf("OAuth accounts: %d\n", len(pc.OAuth.CredentialFiles))
|
||||
}
|
||||
if len(models) > 0 {
|
||||
fmt.Printf("Discovered models: %s\n", strings.Join(models, ", "))
|
||||
}
|
||||
if running, reloadErr := triggerGatewayReload(); reloadErr == nil {
|
||||
fmt.Println("Gateway hot reload signal sent")
|
||||
} else if running {
|
||||
fmt.Printf("Hot reload not applied: %v\n", reloadErr)
|
||||
}
|
||||
}
|
||||
|
||||
func providerNames(cfg *config.Config) []string {
|
||||
names := []string{"proxy"}
|
||||
for k := range cfg.Providers.Proxies {
|
||||
@@ -370,6 +505,35 @@ func parseCSV(raw string) []string {
|
||||
return out
|
||||
}
|
||||
|
||||
func defaultInt(value, fallback int) int {
|
||||
if value > 0 {
|
||||
return value
|
||||
}
|
||||
return fallback
|
||||
}
|
||||
|
||||
func firstNonEmptyString(values ...string) string {
|
||||
for _, value := range values {
|
||||
if trimmed := strings.TrimSpace(value); trimmed != "" {
|
||||
return trimmed
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func appendUniqueCSV(values []string, value string) []string {
|
||||
trimmed := strings.TrimSpace(value)
|
||||
if trimmed == "" {
|
||||
return values
|
||||
}
|
||||
for _, item := range values {
|
||||
if strings.TrimSpace(item) == trimmed {
|
||||
return values
|
||||
}
|
||||
}
|
||||
return append(values, trimmed)
|
||||
}
|
||||
|
||||
func parseIntOrDefault(raw string, def int) int {
|
||||
raw = strings.TrimSpace(raw)
|
||||
if raw == "" {
|
||||
|
||||
Reference in New Issue
Block a user