Add OAuth provider runtime and providers UI

This commit is contained in:
lpf
2026-03-11 15:47:49 +08:00
parent d9872c3da7
commit 1c0e463d07
52 changed files with 9772 additions and 901 deletions

View File

@@ -31,6 +31,7 @@ import (
"github.com/YspCoder/clawgo/pkg/channels"
cfgpkg "github.com/YspCoder/clawgo/pkg/config"
"github.com/YspCoder/clawgo/pkg/nodes"
"github.com/YspCoder/clawgo/pkg/providers"
"github.com/YspCoder/clawgo/pkg/tools"
"github.com/gorilla/websocket"
"rsc.io/qr"
@@ -73,6 +74,8 @@ type Server struct {
liveSubagents map[string]*liveSubagentGroup
whatsAppBridge *channels.WhatsAppBridgeService
whatsAppBase string
oauthFlowMu sync.Mutex
oauthFlows map[string]*providers.OAuthPendingFlow
}
var nodesWebsocketUpgrader = websocket.Upgrader{
@@ -96,6 +99,7 @@ func NewServer(host string, port int, token string, mgr *nodes.Manager) *Server
artifactStats: map[string]interface{}{},
liveRuntimeSubs: map[chan []byte]struct{}{},
liveSubagents: map[string]*liveSubagentGroup{},
oauthFlows: map[string]*providers.OAuthPendingFlow{},
}
}
@@ -449,6 +453,12 @@ func (s *Server) Start(ctx context.Context) error {
mux.HandleFunc("/webui/api/chat/live", s.handleWebUIChatLive)
mux.HandleFunc("/webui/api/runtime", s.handleWebUIRuntime)
mux.HandleFunc("/webui/api/version", s.handleWebUIVersion)
mux.HandleFunc("/webui/api/provider/oauth/start", s.handleWebUIProviderOAuthStart)
mux.HandleFunc("/webui/api/provider/oauth/complete", s.handleWebUIProviderOAuthComplete)
mux.HandleFunc("/webui/api/provider/oauth/import", s.handleWebUIProviderOAuthImport)
mux.HandleFunc("/webui/api/provider/oauth/accounts", s.handleWebUIProviderOAuthAccounts)
mux.HandleFunc("/webui/api/provider/runtime", s.handleWebUIProviderRuntime)
mux.HandleFunc("/webui/api/provider/runtime/summary", s.handleWebUIProviderRuntimeSummary)
mux.HandleFunc("/webui/api/whatsapp/status", s.handleWebUIWhatsAppStatus)
mux.HandleFunc("/webui/api/whatsapp/logout", s.handleWebUIWhatsAppLogout)
mux.HandleFunc("/webui/api/whatsapp/qr.svg", s.handleWebUIWhatsAppQR)
@@ -979,6 +989,499 @@ func (s *Server) handleWebUIUpload(w http.ResponseWriter, r *http.Request) {
_ = json.NewEncoder(w).Encode(map[string]interface{}{"ok": true, "path": path, "name": h.Filename})
}
func (s *Server) handleWebUIProviderOAuthStart(w http.ResponseWriter, r *http.Request) {
if !s.checkAuth(r) {
http.Error(w, "unauthorized", http.StatusUnauthorized)
return
}
if r.Method != http.MethodPost && r.Method != http.MethodGet {
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
return
}
var body struct {
Provider string `json:"provider"`
AccountLabel string `json:"account_label"`
ProviderConfig cfgpkg.ProviderConfig `json:"provider_config"`
}
if r.Method == http.MethodPost {
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
http.Error(w, "invalid json", http.StatusBadRequest)
return
}
} else {
body.Provider = strings.TrimSpace(r.URL.Query().Get("provider"))
body.AccountLabel = strings.TrimSpace(r.URL.Query().Get("account_label"))
}
cfg, pc, err := s.resolveProviderConfig(strings.TrimSpace(body.Provider), body.ProviderConfig)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
_ = cfg
timeout := pc.TimeoutSec
if timeout <= 0 {
timeout = 90
}
loginMgr, err := providers.NewOAuthLoginManager(pc, time.Duration(timeout)*time.Second)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
flow, err := loginMgr.StartManualFlow()
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
flowID := fmt.Sprintf("%d", time.Now().UnixNano())
s.oauthFlowMu.Lock()
s.oauthFlows[flowID] = flow
s.oauthFlowMu.Unlock()
_ = json.NewEncoder(w).Encode(map[string]interface{}{
"ok": true,
"flow_id": flowID,
"mode": flow.Mode,
"auth_url": flow.AuthURL,
"user_code": flow.UserCode,
"instructions": flow.Instructions,
"account_label": strings.TrimSpace(body.AccountLabel),
})
}
func (s *Server) handleWebUIProviderOAuthComplete(w http.ResponseWriter, r *http.Request) {
if !s.checkAuth(r) {
http.Error(w, "unauthorized", http.StatusUnauthorized)
return
}
if r.Method != http.MethodPost {
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
return
}
var body struct {
Provider string `json:"provider"`
FlowID string `json:"flow_id"`
CallbackURL string `json:"callback_url"`
AccountLabel string `json:"account_label"`
ProviderConfig cfgpkg.ProviderConfig `json:"provider_config"`
}
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
http.Error(w, "invalid json", http.StatusBadRequest)
return
}
cfg, pc, err := s.resolveProviderConfig(strings.TrimSpace(body.Provider), body.ProviderConfig)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
timeout := pc.TimeoutSec
if timeout <= 0 {
timeout = 90
}
loginMgr, err := providers.NewOAuthLoginManager(pc, time.Duration(timeout)*time.Second)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
s.oauthFlowMu.Lock()
flow := s.oauthFlows[strings.TrimSpace(body.FlowID)]
delete(s.oauthFlows, strings.TrimSpace(body.FlowID))
s.oauthFlowMu.Unlock()
if flow == nil {
http.Error(w, "oauth flow not found", http.StatusBadRequest)
return
}
session, models, err := loginMgr.CompleteManualFlowWithOptions(r.Context(), pc.APIBase, flow, body.CallbackURL, providers.OAuthLoginOptions{
AccountLabel: strings.TrimSpace(body.AccountLabel),
})
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
if len(models) > 0 {
pc.Models = models
}
if session.CredentialFile != "" {
pc.OAuth.CredentialFile = session.CredentialFile
pc.OAuth.CredentialFiles = appendUniqueStrings(pc.OAuth.CredentialFiles, session.CredentialFile)
}
if err := s.saveProviderConfig(cfg, body.Provider, pc); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
_ = json.NewEncoder(w).Encode(map[string]interface{}{
"ok": true,
"account": session.Email,
"credential_file": session.CredentialFile,
"models": models,
})
}
func (s *Server) handleWebUIProviderOAuthImport(w http.ResponseWriter, r *http.Request) {
if !s.checkAuth(r) {
http.Error(w, "unauthorized", http.StatusUnauthorized)
return
}
if r.Method != http.MethodPost {
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
return
}
if err := r.ParseMultipartForm(16 << 20); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
providerName := strings.TrimSpace(r.FormValue("provider"))
accountLabel := strings.TrimSpace(r.FormValue("account_label"))
inlineCfgRaw := strings.TrimSpace(r.FormValue("provider_config"))
var inlineCfg cfgpkg.ProviderConfig
if inlineCfgRaw != "" {
if err := json.Unmarshal([]byte(inlineCfgRaw), &inlineCfg); err != nil {
http.Error(w, "invalid provider_config", http.StatusBadRequest)
return
}
}
cfg, pc, err := s.resolveProviderConfig(providerName, inlineCfg)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
file, header, err := r.FormFile("file")
if err != nil {
http.Error(w, "file required", http.StatusBadRequest)
return
}
defer file.Close()
data, err := io.ReadAll(file)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
timeout := pc.TimeoutSec
if timeout <= 0 {
timeout = 90
}
loginMgr, err := providers.NewOAuthLoginManager(pc, time.Duration(timeout)*time.Second)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
session, models, err := loginMgr.ImportAuthJSONWithOptions(r.Context(), pc.APIBase, header.Filename, data, providers.OAuthLoginOptions{
AccountLabel: accountLabel,
})
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
if len(models) > 0 {
pc.Models = models
}
if session.CredentialFile != "" {
pc.OAuth.CredentialFile = session.CredentialFile
pc.OAuth.CredentialFiles = appendUniqueStrings(pc.OAuth.CredentialFiles, session.CredentialFile)
}
if err := s.saveProviderConfig(cfg, providerName, pc); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
_ = json.NewEncoder(w).Encode(map[string]interface{}{
"ok": true,
"account": session.Email,
"credential_file": session.CredentialFile,
"models": models,
})
}
func (s *Server) handleWebUIProviderOAuthAccounts(w http.ResponseWriter, r *http.Request) {
if !s.checkAuth(r) {
http.Error(w, "unauthorized", http.StatusUnauthorized)
return
}
providerName := strings.TrimSpace(r.URL.Query().Get("provider"))
cfg, pc, err := s.loadProviderConfig(providerName)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
_ = cfg
timeout := pc.TimeoutSec
if timeout <= 0 {
timeout = 90
}
loginMgr, err := providers.NewOAuthLoginManager(pc, time.Duration(timeout)*time.Second)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
switch r.Method {
case http.MethodGet:
accounts, err := loginMgr.ListAccounts()
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
_ = json.NewEncoder(w).Encode(map[string]interface{}{"ok": true, "accounts": accounts})
case http.MethodPost:
var body struct {
Action string `json:"action"`
CredentialFile string `json:"credential_file"`
}
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
http.Error(w, "invalid json", http.StatusBadRequest)
return
}
switch strings.ToLower(strings.TrimSpace(body.Action)) {
case "refresh":
account, err := loginMgr.RefreshAccount(r.Context(), body.CredentialFile)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
_ = json.NewEncoder(w).Encode(map[string]interface{}{"ok": true, "account": account})
case "delete":
if err := loginMgr.DeleteAccount(body.CredentialFile); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
pc.OAuth.CredentialFiles = removeStringItem(pc.OAuth.CredentialFiles, body.CredentialFile)
if strings.TrimSpace(pc.OAuth.CredentialFile) == strings.TrimSpace(body.CredentialFile) {
pc.OAuth.CredentialFile = ""
if len(pc.OAuth.CredentialFiles) > 0 {
pc.OAuth.CredentialFile = pc.OAuth.CredentialFiles[0]
}
}
if err := s.saveProviderConfig(cfg, providerName, pc); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
_ = json.NewEncoder(w).Encode(map[string]interface{}{"ok": true, "deleted": true})
case "clear_cooldown":
if err := loginMgr.ClearCooldown(body.CredentialFile); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
_ = json.NewEncoder(w).Encode(map[string]interface{}{"ok": true, "cleared": true})
default:
http.Error(w, "unsupported action", http.StatusBadRequest)
}
default:
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
}
}
func (s *Server) handleWebUIProviderRuntime(w http.ResponseWriter, r *http.Request) {
if !s.checkAuth(r) {
http.Error(w, "unauthorized", http.StatusUnauthorized)
return
}
if r.Method == http.MethodGet {
cfg, err := cfgpkg.LoadConfig(s.configPath)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
query := providers.ProviderRuntimeQuery{
Provider: strings.TrimSpace(r.URL.Query().Get("provider")),
EventKind: strings.TrimSpace(r.URL.Query().Get("kind")),
Reason: strings.TrimSpace(r.URL.Query().Get("reason")),
Target: strings.TrimSpace(r.URL.Query().Get("target")),
Sort: strings.TrimSpace(r.URL.Query().Get("sort")),
ChangesOnly: strings.EqualFold(strings.TrimSpace(r.URL.Query().Get("changes_only")), "true"),
}
if secs, _ := strconv.Atoi(strings.TrimSpace(r.URL.Query().Get("window_sec"))); secs > 0 {
query.Window = time.Duration(secs) * time.Second
}
if limit, _ := strconv.Atoi(strings.TrimSpace(r.URL.Query().Get("limit"))); limit > 0 {
query.Limit = limit
}
if cursor, _ := strconv.Atoi(strings.TrimSpace(r.URL.Query().Get("cursor"))); cursor >= 0 {
query.Cursor = cursor
}
if healthBelow, _ := strconv.Atoi(strings.TrimSpace(r.URL.Query().Get("health_below"))); healthBelow > 0 {
query.HealthBelow = healthBelow
}
if secs, _ := strconv.Atoi(strings.TrimSpace(r.URL.Query().Get("cooldown_until_before_sec"))); secs > 0 {
query.CooldownBefore = time.Now().Add(time.Duration(secs) * time.Second)
}
_ = json.NewEncoder(w).Encode(map[string]interface{}{
"ok": true,
"view": providers.GetProviderRuntimeView(cfg, query),
})
return
}
if r.Method != http.MethodPost {
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
return
}
var body struct {
Provider string `json:"provider"`
Action string `json:"action"`
OnlyExpiring bool `json:"only_expiring"`
}
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
http.Error(w, "invalid json", http.StatusBadRequest)
return
}
switch strings.ToLower(strings.TrimSpace(body.Action)) {
case "clear_api_cooldown":
providers.ClearProviderAPICooldown(strings.TrimSpace(body.Provider))
_ = json.NewEncoder(w).Encode(map[string]interface{}{"ok": true, "cleared": true})
case "clear_history":
providers.ClearProviderRuntimeHistory(strings.TrimSpace(body.Provider))
_ = json.NewEncoder(w).Encode(map[string]interface{}{"ok": true, "cleared": true})
case "refresh_now":
cfg, err := cfgpkg.LoadConfig(s.configPath)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
result, err := providers.RefreshProviderRuntimeNow(cfg, strings.TrimSpace(body.Provider), body.OnlyExpiring)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
_ = json.NewEncoder(w).Encode(map[string]interface{}{"ok": true, "refreshed": true, "result": result})
case "rerank":
cfg, err := cfgpkg.LoadConfig(s.configPath)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
order, err := providers.RerankProviderRuntime(cfg, strings.TrimSpace(body.Provider))
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
_ = json.NewEncoder(w).Encode(map[string]interface{}{"ok": true, "reranked": true, "candidate_order": order})
default:
http.Error(w, "unsupported action", http.StatusBadRequest)
}
}
func (s *Server) handleWebUIProviderRuntimeSummary(w http.ResponseWriter, r *http.Request) {
if !s.checkAuth(r) {
http.Error(w, "unauthorized", http.StatusUnauthorized)
return
}
if r.Method != http.MethodGet {
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
return
}
cfg, err := cfgpkg.LoadConfig(s.configPath)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
query := providers.ProviderRuntimeQuery{
Provider: strings.TrimSpace(r.URL.Query().Get("provider")),
Reason: strings.TrimSpace(r.URL.Query().Get("reason")),
Target: strings.TrimSpace(r.URL.Query().Get("target")),
}
if secs, _ := strconv.Atoi(strings.TrimSpace(r.URL.Query().Get("window_sec"))); secs > 0 {
query.Window = time.Duration(secs) * time.Second
}
if healthBelow, _ := strconv.Atoi(strings.TrimSpace(r.URL.Query().Get("health_below"))); healthBelow > 0 {
query.HealthBelow = healthBelow
}
if query.HealthBelow <= 0 {
query.HealthBelow = 50
}
if secs, _ := strconv.Atoi(strings.TrimSpace(r.URL.Query().Get("cooldown_until_before_sec"))); secs > 0 {
query.CooldownBefore = time.Now().Add(time.Duration(secs) * time.Second)
}
_ = json.NewEncoder(w).Encode(map[string]interface{}{
"ok": true,
"summary": providers.GetProviderRuntimeSummary(cfg, query),
})
}
func (s *Server) loadProviderConfig(name string) (*cfgpkg.Config, cfgpkg.ProviderConfig, error) {
if strings.TrimSpace(s.configPath) == "" {
return nil, cfgpkg.ProviderConfig{}, fmt.Errorf("config path not set")
}
cfg, err := cfgpkg.LoadConfig(s.configPath)
if err != nil {
return nil, cfgpkg.ProviderConfig{}, err
}
providerName := strings.TrimSpace(name)
if providerName == "" || providerName == "proxy" {
return cfg, cfg.Providers.Proxy, nil
}
pc, ok := cfg.Providers.Proxies[providerName]
if !ok {
return nil, cfgpkg.ProviderConfig{}, fmt.Errorf("provider %q not found", providerName)
}
return cfg, pc, nil
}
func (s *Server) resolveProviderConfig(name string, inline cfgpkg.ProviderConfig) (*cfgpkg.Config, cfgpkg.ProviderConfig, error) {
if hasInlineProviderConfig(inline) {
cfg, err := cfgpkg.LoadConfig(s.configPath)
if err != nil {
return nil, cfgpkg.ProviderConfig{}, err
}
return cfg, inline, nil
}
return s.loadProviderConfig(name)
}
func hasInlineProviderConfig(pc cfgpkg.ProviderConfig) bool {
return strings.TrimSpace(pc.APIBase) != "" ||
strings.TrimSpace(pc.APIKey) != "" ||
len(pc.Models) > 0 ||
strings.TrimSpace(pc.Auth) != "" ||
strings.TrimSpace(pc.OAuth.Provider) != ""
}
func (s *Server) saveProviderConfig(cfg *cfgpkg.Config, name string, pc cfgpkg.ProviderConfig) error {
if cfg == nil {
return fmt.Errorf("config is nil")
}
providerName := strings.TrimSpace(name)
if providerName == "" || providerName == "proxy" {
cfg.Providers.Proxy = pc
} else {
if cfg.Providers.Proxies == nil {
cfg.Providers.Proxies = map[string]cfgpkg.ProviderConfig{}
}
cfg.Providers.Proxies[providerName] = pc
}
if err := cfgpkg.SaveConfig(s.configPath, cfg); err != nil {
return err
}
if s.onConfigAfter != nil {
s.onConfigAfter()
} else {
_ = requestSelfReloadSignal()
}
return nil
}
func appendUniqueStrings(values []string, item string) []string {
item = strings.TrimSpace(item)
if item == "" {
return values
}
for _, value := range values {
if strings.TrimSpace(value) == item {
return values
}
}
return append(values, item)
}
func removeStringItem(values []string, item string) []string {
item = strings.TrimSpace(item)
if item == "" {
return values
}
out := make([]string, 0, len(values))
for _, value := range values {
if strings.TrimSpace(value) == item {
continue
}
out = append(out, value)
}
return out
}
func (s *Server) handleWebUIChat(w http.ResponseWriter, r *http.Request) {
if !s.checkAuth(r) {
http.Error(w, "unauthorized", http.StatusUnauthorized)
@@ -1485,6 +1988,15 @@ func (s *Server) handleWebUIRuntime(w http.ResponseWriter, r *http.Request) {
}
func (s *Server) buildWebUIRuntimeSnapshot(ctx context.Context) map[string]interface{} {
var providerPayload map[string]interface{}
if strings.TrimSpace(s.configPath) != "" {
if cfg, err := cfgpkg.LoadConfig(strings.TrimSpace(s.configPath)); err == nil {
providerPayload = providers.GetProviderRuntimeSnapshot(cfg)
}
}
if providerPayload == nil {
providerPayload = map[string]interface{}{"items": []interface{}{}}
}
return map[string]interface{}{
"version": s.webUIVersionPayload(),
"nodes": s.webUINodesPayload(ctx),
@@ -1492,6 +2004,7 @@ func (s *Server) buildWebUIRuntimeSnapshot(ctx context.Context) map[string]inter
"task_queue": s.webUITaskQueuePayload(false),
"ekg": s.webUIEKGSummaryPayload("24h"),
"subagents": s.webUISubagentsRuntimePayload(ctx),
"providers": providerPayload,
}
}